refactor of key-switching & external product

This commit is contained in:
Jean-Philippe Bossuat
2025-05-15 18:24:56 +02:00
parent 723a41acd0
commit ccd7450c5f
15 changed files with 1593 additions and 1740 deletions

View File

@@ -1,8 +1,7 @@
use base2k::{
Backend, FFT64, MatZnxDft, MatZnxDftAlloc, MatZnxDftOps, MatZnxDftScratch, MatZnxDftToMut, MatZnxDftToRef, Module, ScalarZnx,
ScalarZnxDft, ScalarZnxDftToRef, ScalarZnxToRef, Scratch, VecZnx, VecZnxAlloc, VecZnxBig, VecZnxBigOps, VecZnxBigScratch,
VecZnxDft, VecZnxDftAlloc, VecZnxDftOps, VecZnxDftToMut, VecZnxDftToRef, VecZnxOps, VecZnxToMut, VecZnxToRef, ZnxInfos,
ZnxZero,
Backend, FFT64, MatZnxDft, MatZnxDftAlloc, MatZnxDftOps, MatZnxDftToMut, MatZnxDftToRef, Module, ScalarZnx, ScalarZnxDft,
ScalarZnxDftToRef, ScalarZnxToRef, Scratch, VecZnxAlloc, VecZnxDft, VecZnxDftAlloc, VecZnxDftToMut, VecZnxDftToRef,
VecZnxOps, ZnxInfos, ZnxZero,
};
use sampling::source::Source;
@@ -13,7 +12,6 @@ use crate::{
glwe_plaintext::GLWEPlaintext,
keys::SecretKeyFourier,
utils::derive_size,
vec_glwe_product::{VecGLWEProduct, VecGLWEProductScratchSpace},
};
pub struct GGLWECiphertext<C, B: Backend> {
@@ -212,81 +210,3 @@ where
module.vmp_prepare_row(self, row_i, col_j, a);
}
}
impl VecGLWEProductScratchSpace for GGLWECiphertext<Vec<u8>, FFT64> {
fn prod_with_glwe_scratch_space(
module: &Module<FFT64>,
res_size: usize,
a_size: usize,
grlwe_size: usize,
rank_in: usize,
rank_out: usize,
) -> usize {
module.bytes_of_vec_znx_dft(rank_out + 1, grlwe_size)
+ (module.vec_znx_big_normalize_tmp_bytes()
| (module.vmp_apply_tmp_bytes(res_size, a_size, a_size, rank_in, rank_out + 1, grlwe_size)
+ module.bytes_of_vec_znx_dft(rank_in, a_size)))
}
}
impl<C> VecGLWEProduct for GGLWECiphertext<C, FFT64>
where
MatZnxDft<C, FFT64>: MatZnxDftToRef<FFT64> + ZnxInfos,
{
fn prod_with_glwe<R, A>(
&self,
module: &Module<FFT64>,
res: &mut GLWECiphertext<R>,
a: &GLWECiphertext<A>,
scratch: &mut Scratch,
) where
MatZnxDft<C, FFT64>: MatZnxDftToRef<FFT64>,
VecZnx<R>: VecZnxToMut,
VecZnx<A>: VecZnxToRef,
{
let basek: usize = self.basek();
#[cfg(debug_assertions)]
{
assert_eq!(a.rank(), self.rank_in());
assert_eq!(res.rank(), self.rank_out());
assert_eq!(res.basek(), basek);
assert_eq!(a.basek(), basek);
assert_eq!(self.n(), module.n());
assert_eq!(res.n(), module.n());
assert_eq!(a.n(), module.n());
assert!(
scratch.available()
>= GGLWECiphertext::prod_with_glwe_scratch_space(
module,
res.size(),
a.size(),
self.size(),
self.rank_in(),
self.rank_out()
)
);
}
let cols_in: usize = self.rank_in();
let cols_out: usize = self.rank_out() + 1;
let (mut res_dft, scratch1) = scratch.tmp_vec_znx_dft(module, cols_out, self.size()); // Todo optimise
{
let (mut ai_dft, scratch2) = scratch1.tmp_vec_znx_dft(module, cols_in, a.size());
(0..cols_in).for_each(|col_i| {
module.vec_znx_dft(&mut ai_dft, col_i, a, col_i + 1);
});
module.vmp_apply(&mut res_dft, &ai_dft, self, scratch2);
}
let mut res_big: VecZnxBig<&mut [u8], FFT64> = module.vec_znx_idft_consume(res_dft);
module.vec_znx_big_add_small_inplace(&mut res_big, 0, a, 0);
(0..cols_out).for_each(|i| {
module.vec_znx_big_normalize(basek, res, i, &res_big, i, scratch1);
});
}
}

View File

@@ -1,35 +1,31 @@
use base2k::{
Backend, FFT64, MatZnxDft, MatZnxDftAlloc, MatZnxDftOps, MatZnxDftScratch, MatZnxDftToMut, MatZnxDftToRef, Module, ScalarZnx,
ScalarZnxDft, ScalarZnxDftToRef, ScalarZnxToRef, Scratch, VecZnx, VecZnxAlloc, VecZnxBig, VecZnxBigOps, VecZnxBigScratch,
VecZnxDft, VecZnxDftAlloc, VecZnxDftOps, VecZnxDftToMut, VecZnxDftToRef, VecZnxOps, VecZnxToMut, VecZnxToRef, ZnxInfos,
ZnxZero,
Backend, FFT64, MatZnxDft, MatZnxDftAlloc, MatZnxDftOps, MatZnxDftToMut, MatZnxDftToRef, Module, ScalarZnx, ScalarZnxDft,
ScalarZnxDftToRef, ScalarZnxToRef, Scratch, VecZnxAlloc, VecZnxDft, VecZnxDftAlloc, VecZnxDftOps, VecZnxDftToMut,
VecZnxDftToRef, VecZnxOps, ZnxInfos, ZnxZero,
};
use sampling::source::Source;
use crate::{
elem::{GetRow, Infos, SetRow},
gglwe_ciphertext::GGLWECiphertext,
glwe_ciphertext::GLWECiphertext,
glwe_ciphertext_fourier::GLWECiphertextFourier,
glwe_plaintext::GLWEPlaintext,
keys::SecretKeyFourier,
keyswitch_key::GLWESwitchingKey,
utils::derive_size,
vec_glwe_product::{VecGLWEProduct, VecGLWEProductScratchSpace},
};
pub struct GGSWCiphertext<C, B: Backend> {
pub data: MatZnxDft<C, B>,
pub log_base2k: usize,
pub log_k: usize,
pub basek: usize,
pub k: usize,
}
impl<B: Backend> GGSWCiphertext<Vec<u8>, B> {
pub fn new(module: &Module<B>, log_base2k: usize, log_k: usize, rows: usize, rank: usize) -> Self {
pub fn new(module: &Module<B>, basek: usize, k: usize, rows: usize, rank: usize) -> Self {
Self {
data: module.new_mat_znx_dft(rows, rank + 1, rank + 1, derive_size(log_base2k, log_k)),
log_base2k: log_base2k,
log_k: log_k,
data: module.new_mat_znx_dft(rows, rank + 1, rank + 1, derive_size(basek, k)),
basek: basek,
k: k,
}
}
}
@@ -42,11 +38,11 @@ impl<T, B: Backend> Infos for GGSWCiphertext<T, B> {
}
fn basek(&self) -> usize {
self.log_base2k
self.basek
}
fn k(&self) -> usize {
self.log_k
self.k
}
}
@@ -82,35 +78,28 @@ impl GGSWCiphertext<Vec<u8>, FFT64> {
+ module.bytes_of_vec_znx_dft(rank + 1, size)
}
pub fn keyswitch_scratch_space(
pub fn external_product_scratch_space(
module: &Module<FFT64>,
res_size: usize,
lhs: usize,
rhs: usize,
rank_in: usize,
rank_out: usize,
out_size: usize,
in_size: usize,
ggsw_size: usize,
rank: usize,
) -> usize {
<GGLWECiphertext<Vec<u8>, FFT64> as VecGLWEProductScratchSpace>::prod_with_vec_glwe_scratch_space(
module, res_size, lhs, rhs, rank_in, rank_out,
)
let tmp_in: usize = module.bytes_of_vec_znx_dft(rank + 1, in_size);
let tmp_out: usize = module.bytes_of_vec_znx_dft(rank + 1, out_size);
let ggsw: usize = GLWECiphertextFourier::external_product_scratch_space(module, out_size, in_size, ggsw_size, rank);
tmp_in + tmp_out + ggsw
}
pub fn keyswitch_inplace_scratch_space(module: &Module<FFT64>, res_size: usize, rhs: usize, rank: usize) -> usize {
<GGLWECiphertext<Vec<u8>, FFT64> as VecGLWEProductScratchSpace>::prod_with_vec_glwe_inplace_scratch_space(
module, res_size, rhs, rank,
)
}
pub fn external_product_scratch_space(module: &Module<FFT64>, res_size: usize, lhs: usize, rhs: usize, rank: usize) -> usize {
<GGSWCiphertext<Vec<u8>, FFT64> as VecGLWEProductScratchSpace>::prod_with_vec_glwe_scratch_space(
module, res_size, lhs, rhs, rank, rank,
)
}
pub fn external_product_inplace_scratch_space(module: &Module<FFT64>, res_size: usize, rhs: usize, rank: usize) -> usize {
<GGSWCiphertext<Vec<u8>, FFT64> as VecGLWEProductScratchSpace>::prod_with_glwe_inplace_scratch_space(
module, res_size, rhs, rank,
)
pub fn external_product_inplace_scratch_space(
module: &Module<FFT64>,
out_size: usize,
ggsw_size: usize,
rank: usize,
) -> usize {
let tmp: usize = module.bytes_of_vec_znx_dft(rank + 1, out_size);
let ggsw: usize = GLWECiphertextFourier::external_product_inplace_scratch_space(module, out_size, ggsw_size, rank);
tmp + ggsw
}
}
@@ -140,7 +129,7 @@ where
}
let size: usize = self.size();
let log_base2k: usize = self.basek();
let basek: usize = self.basek();
let k: usize = self.k();
let cols: usize = self.rank() + 1;
@@ -149,20 +138,20 @@ where
let mut vec_znx_pt: GLWEPlaintext<&mut [u8]> = GLWEPlaintext {
data: tmp_znx_pt,
basek: log_base2k,
basek: basek,
k: k,
};
let mut vec_znx_ct: GLWECiphertext<&mut [u8]> = GLWECiphertext {
data: tmp_znx_ct,
basek: log_base2k,
basek: basek,
k,
};
(0..self.rows()).for_each(|row_j| {
// Adds the scalar_znx_pt to the i-th limb of the vec_znx_pt
module.vec_znx_add_scalar_inplace(&mut vec_znx_pt, 0, row_j, pt, 0);
module.vec_znx_normalize_inplace(log_base2k, &mut vec_znx_pt, 0, scrach_2);
module.vec_znx_normalize_inplace(basek, &mut vec_znx_pt, 0, scrach_2);
(0..cols).for_each(|col_i| {
// rlwe encrypt of vec_znx_pt into vec_znx_ct
@@ -193,30 +182,6 @@ where
});
}
pub fn keyswitch<DataLhs, DataRhs>(
&mut self,
module: &Module<FFT64>,
lhs: &GGSWCiphertext<DataLhs, FFT64>,
rhs: &GLWESwitchingKey<DataRhs, FFT64>,
scratch: &mut Scratch,
) where
MatZnxDft<DataLhs, FFT64>: MatZnxDftToRef<FFT64>,
MatZnxDft<DataRhs, FFT64>: MatZnxDftToRef<FFT64>,
{
rhs.0.prod_with_vec_glwe(module, self, lhs, scratch);
}
pub fn keyswitch_inplace<DataRhs>(
&mut self,
module: &Module<FFT64>,
rhs: &GLWESwitchingKey<DataRhs, FFT64>,
scratch: &mut Scratch,
) where
MatZnxDft<DataRhs, FFT64>: MatZnxDftToRef<FFT64>,
{
rhs.0.prod_with_vec_glwe_inplace(module, self, scratch);
}
pub fn external_product<DataLhs, DataRhs>(
&mut self,
module: &Module<FFT64>,
@@ -227,7 +192,55 @@ where
MatZnxDft<DataLhs, FFT64>: MatZnxDftToRef<FFT64>,
MatZnxDft<DataRhs, FFT64>: MatZnxDftToRef<FFT64>,
{
rhs.prod_with_vec_glwe(module, self, lhs, scratch);
#[cfg(debug_assertions)]
{
assert_eq!(
self.rank(),
lhs.rank(),
"ggsw_out rank: {} != ggsw_in rank: {}",
self.rank(),
lhs.rank()
);
assert_eq!(
self.rank(),
rhs.rank(),
"ggsw_in rank: {} != ggsw_apply rank: {}",
self.rank(),
rhs.rank()
);
}
let (tmp_in_data, scratch1) = scratch.tmp_vec_znx_dft(module, lhs.rank() + 1, lhs.size());
let mut tmp_in: GLWECiphertextFourier<&mut [u8], FFT64> = GLWECiphertextFourier::<&mut [u8], FFT64> {
data: tmp_in_data,
basek: lhs.basek(),
k: lhs.k(),
};
let (tmp_out_data, scratch2) = scratch1.tmp_vec_znx_dft(module, self.rank() + 1, self.size());
let mut tmp_out: GLWECiphertextFourier<&mut [u8], FFT64> = GLWECiphertextFourier::<&mut [u8], FFT64> {
data: tmp_out_data,
basek: self.basek(),
k: self.k(),
};
(0..self.rank() + 1).for_each(|col_i| {
(0..self.rows()).for_each(|row_j| {
lhs.get_row(module, row_j, col_i, &mut tmp_in);
tmp_out.external_product(module, &tmp_in, rhs, scratch2);
self.set_row(module, row_j, col_i, &tmp_out);
});
});
tmp_out.data.zero();
(self.rows().min(lhs.rows())..self.rows()).for_each(|row_i| {
(0..self.rank() + 1).for_each(|col_j| {
self.set_row(module, row_i, col_j, &tmp_out);
});
});
}
pub fn external_product_inplace<DataRhs>(
@@ -238,7 +251,32 @@ where
) where
MatZnxDft<DataRhs, FFT64>: MatZnxDftToRef<FFT64>,
{
rhs.prod_with_vec_glwe_inplace(module, self, scratch);
#[cfg(debug_assertions)]
{
assert_eq!(
self.rank(),
rhs.rank(),
"ggsw_out rank: {} != ggsw_apply: {}",
self.rank(),
rhs.rank()
);
}
let (tmp_data, scratch1) = scratch.tmp_vec_znx_dft(module, self.rank() + 1, self.size());
let mut tmp: GLWECiphertextFourier<&mut [u8], FFT64> = GLWECiphertextFourier::<&mut [u8], FFT64> {
data: tmp_data,
basek: self.basek(),
k: self.k(),
};
(0..self.rank() + 1).for_each(|col_i| {
(0..self.rows()).for_each(|row_j| {
self.get_row(module, row_j, col_i, &mut tmp);
tmp.external_product_inplace(module, rhs, scratch1);
self.set_row(module, row_j, col_i, &tmp);
});
});
}
}
@@ -270,73 +308,3 @@ where
module.vmp_prepare_row(self, row_i, col_j, a);
}
}
impl VecGLWEProductScratchSpace for GGSWCiphertext<Vec<u8>, FFT64> {
fn prod_with_glwe_scratch_space(
module: &Module<FFT64>,
res_size: usize,
a_size: usize,
rgsw_size: usize,
rank_in: usize,
rank_out: usize,
) -> usize {
module.bytes_of_vec_znx_dft(rank_out + 1, rgsw_size)
+ ((module.bytes_of_vec_znx_dft(rank_in + 1, a_size)
+ module.vmp_apply_tmp_bytes(
res_size,
a_size,
a_size,
rank_in + 1,
rank_out + 1,
rgsw_size,
))
| module.vec_znx_big_normalize_tmp_bytes())
}
}
impl<C> VecGLWEProduct for GGSWCiphertext<C, FFT64>
where
MatZnxDft<C, FFT64>: MatZnxDftToRef<FFT64> + ZnxInfos,
{
fn prod_with_glwe<R, A>(
&self,
module: &Module<FFT64>,
res: &mut GLWECiphertext<R>,
a: &GLWECiphertext<A>,
scratch: &mut Scratch,
) where
VecZnx<R>: VecZnxToMut,
VecZnx<A>: VecZnxToRef,
{
let log_base2k: usize = self.basek();
#[cfg(debug_assertions)]
{
assert_eq!(self.rank(), a.rank());
assert_eq!(self.rank(), res.rank());
assert_eq!(res.basek(), log_base2k);
assert_eq!(a.basek(), log_base2k);
assert_eq!(self.n(), module.n());
assert_eq!(res.n(), module.n());
assert_eq!(a.n(), module.n());
}
let cols: usize = self.rank() + 1;
let (mut res_dft, scratch1) = scratch.tmp_vec_znx_dft(module, cols, self.size()); // Todo optimise
{
let (mut a_dft, scratch2) = scratch1.tmp_vec_znx_dft(module, cols, a.size());
(0..cols).for_each(|col_i| {
module.vec_znx_dft(&mut a_dft, col_i, a, col_i);
});
module.vmp_apply(&mut res_dft, &a_dft, self, scratch2);
}
let res_big: VecZnxBig<&mut [u8], FFT64> = module.vec_znx_idft_consume(res_dft);
(0..cols).for_each(|i| {
module.vec_znx_big_normalize(log_base2k, res, i, &res_big, i, scratch1);
});
}
}

View File

@@ -1,22 +1,20 @@
use base2k::{
AddNormal, Backend, FFT64, FillUniform, MatZnxDft, MatZnxDftToRef, Module, ScalarZnxAlloc, ScalarZnxDft, ScalarZnxDftAlloc,
ScalarZnxDftOps, ScalarZnxDftToRef, Scratch, VecZnx, VecZnxAlloc, VecZnxBig, VecZnxBigAlloc, VecZnxBigOps, VecZnxBigScratch,
VecZnxDft, VecZnxDftAlloc, VecZnxDftOps, VecZnxDftToMut, VecZnxDftToRef, VecZnxOps, VecZnxToMut, VecZnxToRef, ZnxInfos,
ZnxZero,
AddNormal, Backend, FFT64, FillUniform, MatZnxDft, MatZnxDftOps, MatZnxDftScratch, MatZnxDftToRef, Module, ScalarZnxAlloc,
ScalarZnxDft, ScalarZnxDftAlloc, ScalarZnxDftOps, ScalarZnxDftToRef, Scratch, VecZnx, VecZnxAlloc, VecZnxBig, VecZnxBigAlloc,
VecZnxBigOps, VecZnxBigScratch, VecZnxDft, VecZnxDftAlloc, VecZnxDftOps, VecZnxDftToMut, VecZnxDftToRef, VecZnxOps,
VecZnxToMut, VecZnxToRef, ZnxInfos, ZnxZero,
};
use sampling::source::Source;
use crate::{
SIX_SIGMA,
elem::Infos,
gglwe_ciphertext::GGLWECiphertext,
ggsw_ciphertext::GGSWCiphertext,
glwe_ciphertext_fourier::GLWECiphertextFourier,
glwe_plaintext::GLWEPlaintext,
keys::{GLWEPublicKey, SecretDistribution, SecretKeyFourier},
keyswitch_key::GLWESwitchingKey,
utils::derive_size,
vec_glwe_product::{VecGLWEProduct, VecGLWEProductScratchSpace},
};
pub struct GLWECiphertext<C> {
@@ -115,33 +113,50 @@ impl GLWECiphertext<Vec<u8>> {
pub fn keyswitch_scratch_space(
module: &Module<FFT64>,
res_size: usize,
lhs: usize,
rhs: usize,
rank_in: usize,
rank_out: usize,
out_size: usize,
out_rank: usize,
in_size: usize,
in_rank: usize,
ksk_size: usize,
) -> usize {
<GGLWECiphertext<Vec<u8>, FFT64> as VecGLWEProductScratchSpace>::prod_with_glwe_scratch_space(
module, res_size, lhs, rhs, rank_in, rank_out,
)
module.bytes_of_vec_znx_dft(out_rank + 1, ksk_size)
+ (module.vec_znx_big_normalize_tmp_bytes()
| (module.vmp_apply_tmp_bytes(
out_size,
in_size,
in_size,
in_rank + 1,
out_rank + 1,
ksk_size,
) + module.bytes_of_vec_znx_dft(in_size, in_size)))
}
pub fn keyswitch_inplace_scratch_space(module: &Module<FFT64>, res_size: usize, rhs: usize, rank: usize) -> usize {
<GGLWECiphertext<Vec<u8>, FFT64> as VecGLWEProductScratchSpace>::prod_with_glwe_inplace_scratch_space(
module, res_size, rhs, rank,
)
pub fn keyswitch_inplace_scratch_space(module: &Module<FFT64>, out_size: usize, out_rank: usize, ksk_size: usize) -> usize {
GLWECiphertext::keyswitch_scratch_space(module, out_size, out_rank, out_size, out_rank, ksk_size)
}
pub fn external_product_scratch_space(module: &Module<FFT64>, res_size: usize, lhs: usize, rhs: usize, rank: usize) -> usize {
<GGSWCiphertext<Vec<u8>, FFT64> as VecGLWEProductScratchSpace>::prod_with_glwe_scratch_space(
module, res_size, lhs, rhs, rank, rank,
)
pub fn external_product_scratch_space(
module: &Module<FFT64>,
out_size: usize,
in_size: usize,
ggsw_size: usize,
rank: usize,
) -> usize {
module.bytes_of_vec_znx_dft(rank + 1, ggsw_size)
+ ((module.bytes_of_vec_znx_dft(rank + 1, in_size)
+ module.vmp_apply_tmp_bytes(
out_size,
in_size,
in_size, // rows
rank + 1, // cols in
rank + 1, // cols out
ggsw_size,
))
| module.vec_znx_big_normalize_tmp_bytes())
}
pub fn external_product_inplace_scratch_space(module: &Module<FFT64>, res_size: usize, rhs: usize, rank: usize) -> usize {
<GGSWCiphertext<Vec<u8>, FFT64> as VecGLWEProductScratchSpace>::prod_with_glwe_inplace_scratch_space(
module, res_size, rhs, rank,
)
GLWECiphertext::external_product_scratch_space(module, res_size, res_size, rhs, rank)
}
}
@@ -235,7 +250,50 @@ where
VecZnx<DataLhs>: VecZnxToRef,
MatZnxDft<DataRhs, FFT64>: MatZnxDftToRef<FFT64>,
{
rhs.0.prod_with_glwe(module, self, lhs, scratch);
let basek: usize = self.basek();
#[cfg(debug_assertions)]
{
assert_eq!(lhs.rank(), rhs.rank_in());
assert_eq!(self.rank(), rhs.rank_out());
assert_eq!(self.basek(), basek);
assert_eq!(lhs.basek(), basek);
assert_eq!(rhs.n(), module.n());
assert_eq!(self.n(), module.n());
assert_eq!(lhs.n(), module.n());
assert!(
scratch.available()
>= GLWECiphertext::keyswitch_scratch_space(
module,
self.size(),
self.rank(),
lhs.size(),
lhs.rank(),
rhs.size(),
)
);
}
let cols_in: usize = rhs.rank_in();
let cols_out: usize = rhs.rank_out() + 1;
let (mut res_dft, scratch1) = scratch.tmp_vec_znx_dft(module, cols_out, rhs.size()); // Todo optimise
{
let (mut ai_dft, scratch2) = scratch1.tmp_vec_znx_dft(module, cols_in, lhs.size());
(0..cols_in).for_each(|col_i| {
module.vec_znx_dft(&mut ai_dft, col_i, lhs, col_i + 1);
});
module.vmp_apply(&mut res_dft, &ai_dft, rhs, scratch2);
}
let mut res_big: VecZnxBig<&mut [u8], FFT64> = module.vec_znx_idft_consume(res_dft);
module.vec_znx_big_add_small_inplace(&mut res_big, 0, lhs, 0);
(0..cols_out).for_each(|i| {
module.vec_znx_big_normalize(basek, self, i, &res_big, i, scratch1);
});
}
pub fn keyswitch_inplace<DataRhs>(
@@ -246,7 +304,10 @@ where
) where
MatZnxDft<DataRhs, FFT64>: MatZnxDftToRef<FFT64>,
{
rhs.0.prod_with_glwe_inplace(module, self, scratch);
unsafe {
let self_ptr: *mut GLWECiphertext<DataSelf> = self as *mut GLWECiphertext<DataSelf>;
self.keyswitch(&module, &*self_ptr, rhs, scratch);
}
}
pub fn external_product<DataLhs, DataRhs>(
@@ -259,7 +320,36 @@ where
VecZnx<DataLhs>: VecZnxToRef,
MatZnxDft<DataRhs, FFT64>: MatZnxDftToRef<FFT64>,
{
rhs.prod_with_glwe(module, self, lhs, scratch);
let basek: usize = self.basek();
#[cfg(debug_assertions)]
{
assert_eq!(rhs.rank(), lhs.rank());
assert_eq!(rhs.rank(), self.rank());
assert_eq!(self.basek(), basek);
assert_eq!(lhs.basek(), basek);
assert_eq!(rhs.n(), module.n());
assert_eq!(self.n(), module.n());
assert_eq!(lhs.n(), module.n());
}
let cols: usize = rhs.rank() + 1;
let (mut res_dft, scratch1) = scratch.tmp_vec_znx_dft(module, cols, rhs.size()); // Todo optimise
{
let (mut a_dft, scratch2) = scratch1.tmp_vec_znx_dft(module, cols, lhs.size());
(0..cols).for_each(|col_i| {
module.vec_znx_dft(&mut a_dft, col_i, lhs, col_i);
});
module.vmp_apply(&mut res_dft, &a_dft, rhs, scratch2);
}
let res_big: VecZnxBig<&mut [u8], FFT64> = module.vec_znx_idft_consume(res_dft);
(0..cols).for_each(|i| {
module.vec_znx_big_normalize(basek, self, i, &res_big, i, scratch1);
});
}
pub fn external_product_inplace<DataRhs>(
@@ -270,7 +360,10 @@ where
) where
MatZnxDft<DataRhs, FFT64>: MatZnxDftToRef<FFT64>,
{
rhs.prod_with_glwe_inplace(module, self, scratch);
unsafe {
let self_ptr: *mut GLWECiphertext<DataSelf> = self as *mut GLWECiphertext<DataSelf>;
self.external_product(&module, &*self_ptr, rhs, scratch);
}
}
pub(crate) fn encrypt_sk_private<DataPt, DataSk>(

View File

@@ -1,20 +1,13 @@
use base2k::{
Backend, FFT64, MatZnxDft, MatZnxDftToRef, Module, ScalarZnxDft, ScalarZnxDftOps, ScalarZnxDftToRef, Scratch, VecZnx,
VecZnxAlloc, VecZnxBig, VecZnxBigAlloc, VecZnxBigOps, VecZnxBigScratch, VecZnxDft, VecZnxDftAlloc, VecZnxDftOps,
VecZnxDftToMut, VecZnxDftToRef, VecZnxToMut, VecZnxToRef, ZnxZero,
Backend, FFT64, MatZnxDft, MatZnxDftOps, MatZnxDftScratch, MatZnxDftToRef, Module, ScalarZnxDft, ScalarZnxDftOps,
ScalarZnxDftToRef, Scratch, VecZnx, VecZnxAlloc, VecZnxBig, VecZnxBigAlloc, VecZnxBigOps, VecZnxBigScratch, VecZnxDft,
VecZnxDftAlloc, VecZnxDftOps, VecZnxDftToMut, VecZnxDftToRef, VecZnxToMut, VecZnxToRef, ZnxZero,
};
use sampling::source::Source;
use crate::{
elem::Infos,
gglwe_ciphertext::GGLWECiphertext,
ggsw_ciphertext::GGSWCiphertext,
glwe_ciphertext::GLWECiphertext,
glwe_plaintext::GLWEPlaintext,
keys::SecretKeyFourier,
keyswitch_key::GLWESwitchingKey,
utils::derive_size,
vec_glwe_product::{VecGLWEProduct, VecGLWEProductScratchSpace},
elem::Infos, ggsw_ciphertext::GGSWCiphertext, glwe_ciphertext::GLWECiphertext, glwe_plaintext::GLWEPlaintext,
keys::SecretKeyFourier, keyswitch_key::GLWESwitchingKey, utils::derive_size,
};
pub struct GLWECiphertextFourier<C, B: Backend> {
@@ -24,11 +17,11 @@ pub struct GLWECiphertextFourier<C, B: Backend> {
}
impl<B: Backend> GLWECiphertextFourier<Vec<u8>, B> {
pub fn new(module: &Module<B>, log_base2k: usize, log_k: usize, rank: usize) -> Self {
pub fn new(module: &Module<B>, basek: usize, k: usize, rank: usize) -> Self {
Self {
data: module.new_vec_znx_dft(rank + 1, derive_size(log_base2k, log_k)),
basek: log_base2k,
k: log_k,
data: module.new_vec_znx_dft(rank + 1, derive_size(basek, k)),
basek: basek,
k: k,
}
}
}
@@ -92,33 +85,56 @@ impl GLWECiphertextFourier<Vec<u8>, FFT64> {
pub fn keyswitch_scratch_space(
module: &Module<FFT64>,
res_size: usize,
lhs: usize,
rhs: usize,
rank_in: usize,
rank_out: usize,
out_size: usize,
out_rank: usize,
in_size: usize,
in_rank: usize,
ksk_size: usize,
) -> usize {
<GGLWECiphertext<Vec<u8>, FFT64> as VecGLWEProductScratchSpace>::prod_with_glwe_fourier_scratch_space(
module, res_size, lhs, rhs, rank_in, rank_out,
)
let res_dft: usize = module.bytes_of_vec_znx_dft(out_rank + 1, out_size);
let vmp = module.bytes_of_vec_znx_dft(in_rank, in_size)
+ module.vmp_apply_tmp_bytes(
out_size,
in_size,
in_size,
in_rank + 1,
out_rank + 1,
ksk_size,
);
let res_small: usize = module.bytes_of_vec_znx(out_rank + 1, out_size);
let add_a0: usize = module.bytes_of_vec_znx_big(1, in_size) + module.vec_znx_idft_tmp_bytes();
let normalize: usize = module.vec_znx_big_normalize_tmp_bytes();
res_dft + (vmp | add_a0 | (res_small + normalize))
}
pub fn keyswitch_inplace_scratch_space(module: &Module<FFT64>, res_size: usize, rhs: usize, rank: usize) -> usize {
<GGLWECiphertext<Vec<u8>, FFT64> as VecGLWEProductScratchSpace>::prod_with_glwe_fourier_inplace_scratch_space(
module, res_size, rhs, rank,
)
pub fn keyswitch_inplace_scratch_space(module: &Module<FFT64>, out_size: usize, out_rank: usize, ksk_size: usize) -> usize {
Self::keyswitch_scratch_space(module, out_size, out_rank, out_size, out_rank, ksk_size)
}
pub fn external_product_scratch_space(module: &Module<FFT64>, res_size: usize, lhs: usize, rhs: usize, rank: usize) -> usize {
<GGSWCiphertext<Vec<u8>, FFT64> as VecGLWEProductScratchSpace>::prod_with_glwe_fourier_scratch_space(
module, res_size, lhs, rhs, rank, rank,
)
pub fn external_product_scratch_space(
module: &Module<FFT64>,
out_size: usize,
in_size: usize,
ggsw_size: usize,
rank: usize,
) -> usize {
let res_dft: usize = module.bytes_of_vec_znx_dft(rank + 1, out_size);
let vmp: usize = module.vmp_apply_tmp_bytes(out_size, in_size, in_size, rank + 1, rank + 1, ggsw_size);
let res_small: usize = module.bytes_of_vec_znx(rank + 1, out_size);
let normalize: usize = module.vec_znx_big_normalize_tmp_bytes();
res_dft + (vmp | (res_small + normalize))
}
pub fn external_product_inplace_scratch_space(module: &Module<FFT64>, res_size: usize, rhs: usize, rank: usize) -> usize {
<GGSWCiphertext<Vec<u8>, FFT64> as VecGLWEProductScratchSpace>::prod_with_glwe_fourier_inplace_scratch_space(
module, res_size, rhs, rank,
)
pub fn external_product_inplace_scratch_space(
module: &Module<FFT64>,
out_size: usize,
ggsw_size: usize,
rank: usize,
) -> usize {
Self::external_product_scratch_space(module, out_size, out_size, ggsw_size, rank)
}
}
@@ -158,7 +174,61 @@ where
VecZnxDft<DataLhs, FFT64>: VecZnxDftToRef<FFT64>,
MatZnxDft<DataRhs, FFT64>: MatZnxDftToRef<FFT64>,
{
rhs.0.prod_with_glwe_fourier(module, self, lhs, scratch);
let basek: usize = self.basek();
#[cfg(debug_assertions)]
{
assert_eq!(lhs.rank(), rhs.rank_in());
assert_eq!(self.rank(), rhs.rank_out());
assert_eq!(self.basek(), basek);
assert_eq!(lhs.basek(), basek);
assert_eq!(rhs.n(), module.n());
assert_eq!(self.n(), module.n());
assert_eq!(lhs.n(), module.n());
assert!(
scratch.available()
>= GLWECiphertextFourier::keyswitch_scratch_space(
module,
self.size(),
self.rank(),
lhs.size(),
lhs.rank(),
rhs.size(),
)
);
}
let cols_in: usize = rhs.rank_in();
let cols_out: usize = rhs.rank_out() + 1;
// Buffer of the result of VMP in DFT
let (mut res_dft, scratch1) = scratch.tmp_vec_znx_dft(module, cols_out, rhs.size()); // Todo optimise
{
// Applies VMP
let (mut ai_dft, scratch2) = scratch1.tmp_vec_znx_dft(module, cols_in, lhs.size());
(0..cols_in).for_each(|col_i| {
module.vec_znx_dft_copy(&mut ai_dft, col_i, lhs, col_i + 1);
});
module.vmp_apply(&mut res_dft, &ai_dft, rhs, scratch2);
}
// Switches result of VMP outside of DFT
let mut res_big: VecZnxBig<&mut [u8], FFT64> = module.vec_znx_idft_consume::<&mut [u8]>(res_dft);
{
// Switches lhs 0-th outside of DFT domain and adds on
let (mut a0_big, scratch2) = scratch1.tmp_vec_znx_big(module, 1, lhs.size());
module.vec_znx_idft(&mut a0_big, 0, lhs, 0, scratch2);
module.vec_znx_big_add_inplace(&mut res_big, 0, &a0_big, 0);
}
// Space fr normalized VMP result outside of DFT domain
let (mut res_small, scratch2) = scratch1.tmp_vec_znx(module, cols_out, lhs.size());
(0..cols_out).for_each(|i| {
module.vec_znx_big_normalize(basek, &mut res_small, i, &res_big, i, scratch2);
module.vec_znx_dft(self, i, &res_small, i);
});
}
pub fn keyswitch_inplace<DataRhs>(
@@ -169,7 +239,10 @@ where
) where
MatZnxDft<DataRhs, FFT64>: MatZnxDftToRef<FFT64>,
{
rhs.0.prod_with_glwe_fourier_inplace(module, self, scratch);
unsafe {
let self_ptr: *mut GLWECiphertextFourier<DataSelf, FFT64> = self as *mut GLWECiphertextFourier<DataSelf, FFT64>;
self.keyswitch(&module, &*self_ptr, rhs, scratch);
}
}
pub fn external_product<DataLhs, DataRhs>(
@@ -182,7 +255,37 @@ where
VecZnxDft<DataLhs, FFT64>: VecZnxDftToRef<FFT64>,
MatZnxDft<DataRhs, FFT64>: MatZnxDftToRef<FFT64>,
{
rhs.prod_with_glwe_fourier(module, self, lhs, scratch);
let basek: usize = self.basek();
#[cfg(debug_assertions)]
{
assert_eq!(rhs.rank(), lhs.rank());
assert_eq!(rhs.rank(), self.rank());
assert_eq!(self.basek(), basek);
assert_eq!(lhs.basek(), basek);
assert_eq!(rhs.n(), module.n());
assert_eq!(self.n(), module.n());
assert_eq!(lhs.n(), module.n());
}
let cols: usize = rhs.rank() + 1;
// Space for VMP result in DFT domain and high precision
let (mut res_dft, scratch1) = scratch.tmp_vec_znx_dft(module, cols, rhs.size());
{
module.vmp_apply(&mut res_dft, lhs, rhs, scratch1);
}
// VMP result in high precision
let res_big: VecZnxBig<&mut [u8], FFT64> = module.vec_znx_idft_consume::<&mut [u8]>(res_dft);
// Space for VMP result normalized
let (mut res_small, scratch2) = scratch1.tmp_vec_znx(module, cols, rhs.size());
(0..cols).for_each(|i| {
module.vec_znx_big_normalize(basek, &mut res_small, i, &res_big, i, scratch2);
module.vec_znx_dft(self, i, &res_small, i);
});
}
pub fn external_product_inplace<DataRhs>(
@@ -193,7 +296,10 @@ where
) where
MatZnxDft<DataRhs, FFT64>: MatZnxDftToRef<FFT64>,
{
rhs.prod_with_glwe_fourier_inplace(module, self, scratch);
unsafe {
let self_ptr: *mut GLWECiphertextFourier<DataSelf, FFT64> = self as *mut GLWECiphertextFourier<DataSelf, FFT64>;
self.external_product(&module, &*self_ptr, rhs, scratch);
}
}
}
@@ -247,6 +353,7 @@ where
pt.k = pt.k().min(self.k());
}
#[allow(dead_code)]
pub(crate) fn idft<DataRes>(&self, module: &Module<FFT64>, res: &mut GLWECiphertext<DataRes>, scratch: &mut Scratch)
where
GLWECiphertext<DataRes>: VecZnxToMut,

View File

@@ -43,10 +43,10 @@ where
}
impl GLWEPlaintext<Vec<u8>> {
pub fn new<B: Backend>(module: &Module<B>, base2k: usize, k: usize) -> Self {
pub fn new<B: Backend>(module: &Module<B>, basek: usize, k: usize) -> Self {
Self {
data: module.new_vec_znx(1, derive_size(base2k, k)),
basek: base2k,
data: module.new_vec_znx(1, derive_size(basek, k)),
basek: basek,
k,
}
}

View File

@@ -1,6 +1,6 @@
use base2k::{
Backend, FFT64, MatZnxDft, MatZnxDftOps, MatZnxDftToMut, MatZnxDftToRef, Module, ScalarZnx, ScalarZnxDft, ScalarZnxDftToRef,
ScalarZnxToRef, Scratch, VecZnxDft, VecZnxDftToMut, VecZnxDftToRef,
ScalarZnxToRef, Scratch, VecZnxDft, VecZnxDftAlloc, VecZnxDftToMut, VecZnxDftToRef, ZnxZero,
};
use sampling::source::Source;
@@ -10,7 +10,6 @@ use crate::{
ggsw_ciphertext::GGSWCiphertext,
glwe_ciphertext_fourier::GLWECiphertextFourier,
keys::{SecretKey, SecretKeyFourier},
vec_glwe_product::{VecGLWEProduct, VecGLWEProductScratchSpace},
};
pub struct GLWESwitchingKey<Data, B: Backend>(pub(crate) GGLWECiphertext<Data, B>);
@@ -39,6 +38,20 @@ impl<T, B: Backend> Infos for GLWESwitchingKey<T, B> {
}
}
impl<T, B: Backend> GLWESwitchingKey<T, B> {
pub fn rank(&self) -> usize {
self.0.data.cols_out() - 1
}
pub fn rank_in(&self) -> usize {
self.0.data.cols_in()
}
pub fn rank_out(&self) -> usize {
self.0.data.cols_out() - 1
}
}
impl<DataSelf, B: Backend> MatZnxDftToMut<B> for GLWESwitchingKey<DataSelf, B>
where
MatZnxDft<DataSelf, B>: MatZnxDftToMut<B>,
@@ -131,33 +144,46 @@ where
impl GLWESwitchingKey<Vec<u8>, FFT64> {
pub fn keyswitch_scratch_space(
module: &Module<FFT64>,
res_size: usize,
lhs: usize,
rhs: usize,
rank_in: usize,
rank_out: usize,
out_size: usize,
out_rank: usize,
in_size: usize,
in_rank: usize,
ksk_size: usize,
) -> usize {
<GGLWECiphertext<Vec<u8>, FFT64> as VecGLWEProductScratchSpace>::prod_with_vec_glwe_scratch_space(
module, res_size, lhs, rhs, rank_in, rank_out,
)
let tmp_in: usize = module.bytes_of_vec_znx_dft(in_rank + 1, in_size);
let tmp_out: usize = module.bytes_of_vec_znx_dft(out_rank + 1, out_size);
let ksk: usize = GLWECiphertextFourier::keyswitch_scratch_space(module, out_size, out_rank, in_size, in_rank, ksk_size);
tmp_in + tmp_out + ksk
}
pub fn keyswitch_inplace_scratch_space(module: &Module<FFT64>, res_size: usize, rhs: usize, rank: usize) -> usize {
<GGLWECiphertext<Vec<u8>, FFT64> as VecGLWEProductScratchSpace>::prod_with_vec_glwe_inplace_scratch_space(
module, res_size, rhs, rank,
)
pub fn keyswitch_inplace_scratch_space(module: &Module<FFT64>, out_size: usize, out_rank: usize, ksk_size: usize) -> usize {
let tmp: usize = module.bytes_of_vec_znx_dft(out_rank + 1, out_size);
let ksk: usize = GLWECiphertextFourier::keyswitch_inplace_scratch_space(module, out_size, out_rank, ksk_size);
tmp + ksk
}
pub fn external_product_scratch_space(module: &Module<FFT64>, res_size: usize, lhs: usize, rhs: usize, rank: usize) -> usize {
<GGSWCiphertext<Vec<u8>, FFT64> as VecGLWEProductScratchSpace>::prod_with_vec_glwe_scratch_space(
module, res_size, lhs, rhs, rank, rank,
)
pub fn external_product_scratch_space(
module: &Module<FFT64>,
out_size: usize,
in_size: usize,
ggsw_size: usize,
rank: usize,
) -> usize {
let tmp_in: usize = module.bytes_of_vec_znx_dft(rank + 1, in_size);
let tmp_out: usize = module.bytes_of_vec_znx_dft(rank + 1, out_size);
let ggsw: usize = GLWECiphertextFourier::external_product_scratch_space(module, out_size, in_size, ggsw_size, rank);
tmp_in + tmp_out + ggsw
}
pub fn external_product_inplace_scratch_space(module: &Module<FFT64>, res_size: usize, rhs: usize, rank: usize) -> usize {
<GGSWCiphertext<Vec<u8>, FFT64> as VecGLWEProductScratchSpace>::prod_with_glwe_inplace_scratch_space(
module, res_size, rhs, rank,
)
pub fn external_product_inplace_scratch_space(
module: &Module<FFT64>,
out_size: usize,
ggsw_size: usize,
rank: usize,
) -> usize {
let tmp: usize = module.bytes_of_vec_znx_dft(rank + 1, out_size);
let ggsw: usize = GLWECiphertextFourier::external_product_inplace_scratch_space(module, out_size, ggsw_size, rank);
tmp + ggsw
}
}
@@ -175,8 +201,62 @@ where
MatZnxDft<DataLhs, FFT64>: MatZnxDftToRef<FFT64>,
MatZnxDft<DataRhs, FFT64>: MatZnxDftToRef<FFT64>,
{
rhs.0
.prod_with_vec_glwe(module, &mut self.0, &lhs.0, scratch);
#[cfg(debug_assertions)]
{
assert_eq!(
self.rank_in(),
lhs.rank_in(),
"ksk_out input rank: {} != ksk_in input rank: {}",
self.rank_in(),
lhs.rank_in()
);
assert_eq!(
lhs.rank_out(),
rhs.rank_in(),
"ksk_in output rank: {} != ksk_apply input rank: {}",
self.rank_out(),
rhs.rank_in()
);
assert_eq!(
self.rank_out(),
rhs.rank_out(),
"ksk_out output rank: {} != ksk_apply output rank: {}",
self.rank_out(),
rhs.rank_out()
);
}
let (tmp_in_data, scratch1) = scratch.tmp_vec_znx_dft(module, lhs.rank_out() + 1, lhs.size());
let mut tmp_in: GLWECiphertextFourier<&mut [u8], FFT64> = GLWECiphertextFourier::<&mut [u8], FFT64> {
data: tmp_in_data,
basek: lhs.basek(),
k: lhs.k(),
};
let (tmp_out_data, scratch2) = scratch1.tmp_vec_znx_dft(module, self.rank_out() + 1, self.size());
let mut tmp_out: GLWECiphertextFourier<&mut [u8], FFT64> = GLWECiphertextFourier::<&mut [u8], FFT64> {
data: tmp_out_data,
basek: self.basek(),
k: self.k(),
};
(0..self.rank_in()).for_each(|col_i| {
(0..self.rows()).for_each(|row_j| {
lhs.get_row(module, row_j, col_i, &mut tmp_in);
tmp_out.keyswitch(module, &tmp_in, rhs, scratch2);
self.set_row(module, row_j, col_i, &tmp_out);
});
});
tmp_out.data.zero();
(self.rows().min(lhs.rows())..self.rows()).for_each(|row_i| {
(0..self.rank_in()).for_each(|col_j| {
self.set_row(module, row_i, col_j, &tmp_out);
});
});
}
pub fn keyswitch_inplace<DataRhs>(
@@ -187,8 +267,32 @@ where
) where
MatZnxDft<DataRhs, FFT64>: MatZnxDftToRef<FFT64>,
{
rhs.0
.prod_with_vec_glwe_inplace(module, &mut self.0, scratch);
#[cfg(debug_assertions)]
{
assert_eq!(
self.rank_out(),
rhs.rank_out(),
"ksk_out output rank: {} != ksk_apply output rank: {}",
self.rank_out(),
rhs.rank_out()
);
}
let (tmp_data, scratch1) = scratch.tmp_vec_znx_dft(module, self.rank_out() + 1, self.size());
let mut tmp: GLWECiphertextFourier<&mut [u8], FFT64> = GLWECiphertextFourier::<&mut [u8], FFT64> {
data: tmp_data,
basek: self.basek(),
k: self.k(),
};
(0..self.rank_in()).for_each(|col_i| {
(0..self.rows()).for_each(|row_j| {
self.get_row(module, row_j, col_i, &mut tmp);
tmp.keyswitch_inplace(module, rhs, scratch1);
self.set_row(module, row_j, col_i, &tmp);
});
});
}
pub fn external_product<DataLhs, DataRhs>(
@@ -201,7 +305,62 @@ where
MatZnxDft<DataLhs, FFT64>: MatZnxDftToRef<FFT64>,
MatZnxDft<DataRhs, FFT64>: MatZnxDftToRef<FFT64>,
{
rhs.prod_with_vec_glwe(module, &mut self.0, &lhs.0, scratch);
#[cfg(debug_assertions)]
{
assert_eq!(
self.rank_in(),
lhs.rank_in(),
"ksk_out input rank: {} != ksk_in input rank: {}",
self.rank_in(),
lhs.rank_in()
);
assert_eq!(
lhs.rank_out(),
rhs.rank(),
"ksk_in output rank: {} != ggsw rank: {}",
self.rank_out(),
rhs.rank()
);
assert_eq!(
self.rank_out(),
rhs.rank(),
"ksk_out output rank: {} != ggsw rank: {}",
self.rank_out(),
rhs.rank()
);
}
let (tmp_in_data, scratch1) = scratch.tmp_vec_znx_dft(module, lhs.rank_out() + 1, lhs.size());
let mut tmp_in: GLWECiphertextFourier<&mut [u8], FFT64> = GLWECiphertextFourier::<&mut [u8], FFT64> {
data: tmp_in_data,
basek: lhs.basek(),
k: lhs.k(),
};
let (tmp_out_data, scratch2) = scratch1.tmp_vec_znx_dft(module, self.rank_out() + 1, self.size());
let mut tmp_out: GLWECiphertextFourier<&mut [u8], FFT64> = GLWECiphertextFourier::<&mut [u8], FFT64> {
data: tmp_out_data,
basek: self.basek(),
k: self.k(),
};
(0..self.rank_in()).for_each(|col_i| {
(0..self.rows()).for_each(|row_j| {
lhs.get_row(module, row_j, col_i, &mut tmp_in);
tmp_out.external_product(module, &tmp_in, rhs, scratch2);
self.set_row(module, row_j, col_i, &tmp_out);
});
});
tmp_out.data.zero();
(self.rows().min(lhs.rows())..self.rows()).for_each(|row_i| {
(0..self.rank_in()).for_each(|col_j| {
self.set_row(module, row_i, col_j, &tmp_out);
});
});
}
pub fn external_product_inplace<DataRhs>(
@@ -212,6 +371,31 @@ where
) where
MatZnxDft<DataRhs, FFT64>: MatZnxDftToRef<FFT64>,
{
rhs.prod_with_vec_glwe_inplace(module, &mut self.0, scratch);
#[cfg(debug_assertions)]
{
assert_eq!(
self.rank_out(),
rhs.rank(),
"ksk_out output rank: {} != ggsw rank: {}",
self.rank_out(),
rhs.rank()
);
}
let (tmp_data, scratch1) = scratch.tmp_vec_znx_dft(module, self.rank_out() + 1, self.size());
let mut tmp: GLWECiphertextFourier<&mut [u8], FFT64> = GLWECiphertextFourier::<&mut [u8], FFT64> {
data: tmp_data,
basek: self.basek(),
k: self.k(),
};
(0..self.rank_in()).for_each(|col_i| {
(0..self.rows()).for_each(|row_j| {
self.get_row(module, row_j, col_i, &mut tmp);
tmp.external_product_inplace(module, rhs, scratch1);
self.set_row(module, row_j, col_i, &tmp);
});
});
}
}

View File

@@ -9,6 +9,5 @@ pub mod keyswitch_key;
#[cfg(test)]
mod test_fft64;
mod utils;
pub mod vec_glwe_product;
pub(crate) const SIX_SIGMA: f64 = 6.0;

File diff suppressed because it is too large Load Diff

View File

@@ -1,575 +1,345 @@
// use base2k::{
// FFT64, Module, ScalarZnx, ScalarZnxAlloc, ScalarZnxDftOps, ScratchOwned, Stats, VecZnxBig, VecZnxBigAlloc, VecZnxBigOps,
// VecZnxDft, VecZnxDftAlloc, VecZnxDftOps, VecZnxOps, VecZnxToMut, ZnxViewMut, ZnxZero,
// };
// use sampling::source::Source;
//
// use crate::{
// elem::{GetRow, Infos},
// ggsw_ciphertext::GGSWCiphertext,
// glwe_ciphertext_fourier::GLWECiphertextFourier,
// glwe_plaintext::GLWEPlaintext,
// keys::{SecretKey, SecretKeyFourier},
// keyswitch_key::GLWESwitchingKey,
// test_fft64::gglwe::noise_grlwe_rlwe_product,
// };
//
// #[test]
// fn encrypt_sk() {
// let module: Module<FFT64> = Module::<FFT64>::new(2048);
// let log_base2k: usize = 8;
// let log_k_ct: usize = 54;
// let rows: usize = 4;
// let rank: usize = 1;
//
// let sigma: f64 = 3.2;
//
// let mut ct: GGSWCiphertext<Vec<u8>, FFT64> = GGSWCiphertext::new(&module, log_base2k, log_k_ct, rows, rank);
// let mut pt_have: GLWEPlaintext<Vec<u8>> = GLWEPlaintext::new(&module, log_base2k, log_k_ct);
// let mut pt_want: GLWEPlaintext<Vec<u8>> = GLWEPlaintext::new(&module, log_base2k, log_k_ct);
// let mut pt_scalar: ScalarZnx<Vec<u8>> = module.new_scalar_znx(1);
//
// let mut source_xs: Source = Source::new([0u8; 32]);
// let mut source_xe: Source = Source::new([0u8; 32]);
// let mut source_xa: Source = Source::new([0u8; 32]);
//
// pt_scalar.fill_ternary_hw(0, module.n(), &mut source_xs);
//
// let mut scratch: ScratchOwned = ScratchOwned::new(
// GGSWCiphertext::encrypt_sk_scratch_space(&module, rank, ct.size())
// | GLWECiphertextFourier::decrypt_scratch_space(&module, ct.size()),
// );
//
// let mut sk: SecretKey<Vec<u8>> = SecretKey::new(&module, rank);
// sk.fill_ternary_prob(0.5, &mut source_xs);
//
// let mut sk_dft: SecretKeyFourier<Vec<u8>, FFT64> = SecretKeyFourier::new(&module, rank);
// sk_dft.dft(&module, &sk);
//
// ct.encrypt_sk(
// &module,
// &pt_scalar,
// &sk_dft,
// &mut source_xa,
// &mut source_xe,
// sigma,
// scratch.borrow(),
// );
//
// let mut ct_rlwe_dft: GLWECiphertextFourier<Vec<u8>, FFT64> = GLWECiphertextFourier::new(&module, log_base2k, log_k_ct, rank);
// let mut pt_dft: VecZnxDft<Vec<u8>, FFT64> = module.new_vec_znx_dft(1, ct.size());
// let mut pt_big: VecZnxBig<Vec<u8>, FFT64> = module.new_vec_znx_big(1, ct.size());
//
// (0..ct.rank()).for_each(|col_j| {
// (0..ct.rows()).for_each(|row_i| {
// module.vec_znx_add_scalar_inplace(&mut pt_want, 0, row_i, &pt_scalar, 0);
//
// if col_j == 1 {
// module.vec_znx_dft(&mut pt_dft, 0, &pt_want, 0);
// module.svp_apply_inplace(&mut pt_dft, 0, &sk_dft, 0);
// module.vec_znx_idft_tmp_a(&mut pt_big, 0, &mut pt_dft, 0);
// module.vec_znx_big_normalize(log_base2k, &mut pt_want, 0, &pt_big, 0, scratch.borrow());
// }
//
// ct.get_row(&module, row_i, col_j, &mut ct_rlwe_dft);
//
// ct_rlwe_dft.decrypt(&module, &mut pt_have, &sk_dft, scratch.borrow());
//
// module.vec_znx_sub_ab_inplace(&mut pt_have, 0, &pt_want, 0);
//
// let std_pt: f64 = pt_have.data.std(0, log_base2k) * (log_k_ct as f64).exp2();
// assert!((sigma - std_pt).abs() <= 0.2, "{} {}", sigma, std_pt);
//
// pt_want.data.zero();
// });
// });
// }
//
// #[test]
// fn keyswitch() {
// let module: Module<FFT64> = Module::<FFT64>::new(2048);
// let log_base2k: usize = 12;
// let log_k_grlwe: usize = 60;
// let log_k_rgsw_in: usize = 45;
// let log_k_rgsw_out: usize = 45;
// let rows: usize = (log_k_rgsw_in + log_base2k - 1) / log_base2k;
//
// let rank: usize = 1;
//
// let sigma: f64 = 3.2;
//
// let mut ct_grlwe: GLWESwitchingKey<Vec<u8>, FFT64> =
// GLWESwitchingKey::new(&module, log_base2k, log_k_grlwe, rows, rank, rank);
// let mut ct_rgsw_in: GGSWCiphertext<Vec<u8>, FFT64> = GGSWCiphertext::new(&module, log_base2k, log_k_rgsw_in, rows, rank);
// let mut ct_rgsw_out: GGSWCiphertext<Vec<u8>, FFT64> = GGSWCiphertext::new(&module, log_base2k, log_k_rgsw_out, rows, rank);
// let mut pt_rgsw: ScalarZnx<Vec<u8>> = module.new_scalar_znx(1);
//
// let mut source_xs: Source = Source::new([0u8; 32]);
// let mut source_xe: Source = Source::new([0u8; 32]);
// let mut source_xa: Source = Source::new([0u8; 32]);
//
// Random input plaintext
// pt_rgsw.fill_ternary_prob(0, 0.5, &mut source_xs);
//
// let mut scratch: ScratchOwned = ScratchOwned::new(
// GLWESwitchingKey::encrypt_sk_scratch_space(&module, rank, ct_grlwe.size())
// | GLWECiphertextFourier::decrypt_scratch_space(&module, ct_rgsw_out.size())
// | GGSWCiphertext::encrypt_sk_scratch_space(&module, rank, ct_rgsw_in.size())
// | GGSWCiphertext::keyswitch_scratch_space(
// &module,
// ct_rgsw_out.size(),
// ct_rgsw_in.size(),
// ct_grlwe.size(),
// ),
// );
//
// let mut sk0: SecretKey<Vec<u8>> = SecretKey::new(&module, rank);
// sk0.fill_ternary_prob(0.5, &mut source_xs);
//
// let mut sk0_dft: SecretKeyFourier<Vec<u8>, FFT64> = SecretKeyFourier::new(&module, rank);
// sk0_dft.dft(&module, &sk0);
//
// let mut sk1: SecretKey<Vec<u8>> = SecretKey::new(&module, rank);
// sk1.fill_ternary_prob(0.5, &mut source_xs);
//
// let mut sk1_dft: SecretKeyFourier<Vec<u8>, FFT64> = SecretKeyFourier::new(&module, rank);
// sk1_dft.dft(&module, &sk1);
//
// ct_grlwe.encrypt_sk(
// &module,
// &sk0.data,
// &sk1_dft,
// &mut source_xa,
// &mut source_xe,
// sigma,
// scratch.borrow(),
// );
//
// ct_rgsw_in.encrypt_sk(
// &module,
// &pt_rgsw,
// &sk0_dft,
// &mut source_xa,
// &mut source_xe,
// sigma,
// scratch.borrow(),
// );
//
// ct_rgsw_out.keyswitch(&module, &ct_rgsw_in, &ct_grlwe, scratch.borrow());
//
// let mut ct_rlwe_dft: GLWECiphertextFourier<Vec<u8>, FFT64> =
// GLWECiphertextFourier::new(&module, log_base2k, log_k_rgsw_out, rank);
// let mut pt: GLWEPlaintext<Vec<u8>> = GLWEPlaintext::new(&module, log_base2k, log_k_rgsw_out);
// let mut pt_dft: VecZnxDft<Vec<u8>, FFT64> = module.new_vec_znx_dft(1, ct_rgsw_out.size());
// let mut pt_big: VecZnxBig<Vec<u8>, FFT64> = module.new_vec_znx_big(1, ct_rgsw_out.size());
// let mut pt_want: GLWEPlaintext<Vec<u8>> = GLWEPlaintext::new(&module, log_base2k, log_k_rgsw_out);
//
// (0..ct_rgsw_out.rank()).for_each(|col_j| {
// (0..ct_rgsw_out.rows()).for_each(|row_i| {
// module.vec_znx_add_scalar_inplace(&mut pt_want, 0, row_i, &pt_rgsw, 0);
//
// if col_j == 1 {
// module.vec_znx_dft(&mut pt_dft, 0, &pt_want, 0);
// module.svp_apply_inplace(&mut pt_dft, 0, &sk0_dft, 0);
// module.vec_znx_idft_tmp_a(&mut pt_big, 0, &mut pt_dft, 0);
// module.vec_znx_big_normalize(log_base2k, &mut pt_want, 0, &pt_big, 0, scratch.borrow());
// }
//
// ct_rgsw_out.get_row(&module, row_i, col_j, &mut ct_rlwe_dft);
// ct_rlwe_dft.decrypt(&module, &mut pt, &sk1_dft, scratch.borrow());
//
// module.vec_znx_sub_ab_inplace(&mut pt, 0, &pt_want, 0);
//
// let noise_have: f64 = pt.data.std(0, log_base2k).log2();
// let noise_want: f64 = noise_grlwe_rlwe_product(
// module.n() as f64,
// log_base2k,
// 0.5,
// 0.5,
// 0f64,
// sigma * sigma,
// 0f64,
// log_k_grlwe,
// log_k_grlwe,
// );
//
// assert!(
// (noise_have - noise_want).abs() <= 0.2,
// "have: {} want: {}",
// noise_have,
// noise_want
// );
//
// pt_want.data.zero();
// });
// });
// }
//
// #[test]
// fn keyswitch_inplace() {
// let module: Module<FFT64> = Module::<FFT64>::new(2048);
// let log_base2k: usize = 12;
// let log_k_grlwe: usize = 60;
// let log_k_rgsw: usize = 45;
// let rows: usize = (log_k_rgsw + log_base2k - 1) / log_base2k;
// let rank: usize = 1;
//
// let sigma: f64 = 3.2;
//
// let mut ct_grlwe: GLWESwitchingKey<Vec<u8>, FFT64> =
// GLWESwitchingKey::new(&module, log_base2k, log_k_grlwe, rows, rank, rank);
// let mut ct_rgsw: GGSWCiphertext<Vec<u8>, FFT64> = GGSWCiphertext::new(&module, log_base2k, log_k_rgsw, rows, rank);
// let mut pt_rgsw: ScalarZnx<Vec<u8>> = module.new_scalar_znx(1);
//
// let mut source_xs: Source = Source::new([0u8; 32]);
// let mut source_xe: Source = Source::new([0u8; 32]);
// let mut source_xa: Source = Source::new([0u8; 32]);
//
// Random input plaintext
// pt_rgsw.fill_ternary_prob(0, 0.5, &mut source_xs);
//
// let mut scratch: ScratchOwned = ScratchOwned::new(
// GLWESwitchingKey::encrypt_sk_scratch_space(&module, rank, ct_grlwe.size())
// | GLWECiphertextFourier::decrypt_scratch_space(&module, ct_rgsw.size())
// | GGSWCiphertext::encrypt_sk_scratch_space(&module, rank, ct_rgsw.size())
// | GGSWCiphertext::keyswitch_inplace_scratch_space(&module, ct_rgsw.size(), ct_grlwe.size()),
// );
//
// let mut sk0: SecretKey<Vec<u8>> = SecretKey::new(&module, rank);
// sk0.fill_ternary_prob(0.5, &mut source_xs);
//
// let mut sk0_dft: SecretKeyFourier<Vec<u8>, FFT64> = SecretKeyFourier::new(&module, rank);
// sk0_dft.dft(&module, &sk0);
//
// let mut sk1: SecretKey<Vec<u8>> = SecretKey::new(&module, rank);
// sk1.fill_ternary_prob(0.5, &mut source_xs);
//
// let mut sk1_dft: SecretKeyFourier<Vec<u8>, FFT64> = SecretKeyFourier::new(&module, rank);
// sk1_dft.dft(&module, &sk1);
//
// ct_grlwe.encrypt_sk(
// &module,
// &sk0.data,
// &sk1_dft,
// &mut source_xa,
// &mut source_xe,
// sigma,
// scratch.borrow(),
// );
//
// ct_rgsw.encrypt_sk(
// &module,
// &pt_rgsw,
// &sk0_dft,
// &mut source_xa,
// &mut source_xe,
// sigma,
// scratch.borrow(),
// );
//
// ct_rgsw.keyswitch_inplace(&module, &ct_grlwe, scratch.borrow());
//
// let mut ct_rlwe_dft: GLWECiphertextFourier<Vec<u8>, FFT64> =
// GLWECiphertextFourier::new(&module, log_base2k, log_k_rgsw, rank);
// let mut pt: GLWEPlaintext<Vec<u8>> = GLWEPlaintext::new(&module, log_base2k, log_k_rgsw);
// let mut pt_dft: VecZnxDft<Vec<u8>, FFT64> = module.new_vec_znx_dft(1, ct_rgsw.size());
// let mut pt_big: VecZnxBig<Vec<u8>, FFT64> = module.new_vec_znx_big(1, ct_rgsw.size());
// let mut pt_want: GLWEPlaintext<Vec<u8>> = GLWEPlaintext::new(&module, log_base2k, log_k_rgsw);
//
// (0..ct_rgsw.rank()).for_each(|col_j| {
// (0..ct_rgsw.rows()).for_each(|row_i| {
// module.vec_znx_add_scalar_inplace(&mut pt_want, 0, row_i, &pt_rgsw, 0);
//
// if col_j == 1 {
// module.vec_znx_dft(&mut pt_dft, 0, &pt_want, 0);
// module.svp_apply_inplace(&mut pt_dft, 0, &sk0_dft, 0);
// module.vec_znx_idft_tmp_a(&mut pt_big, 0, &mut pt_dft, 0);
// module.vec_znx_big_normalize(log_base2k, &mut pt_want, 0, &pt_big, 0, scratch.borrow());
// }
//
// ct_rgsw.get_row(&module, row_i, col_j, &mut ct_rlwe_dft);
// ct_rlwe_dft.decrypt(&module, &mut pt, &sk1_dft, scratch.borrow());
//
// module.vec_znx_sub_ab_inplace(&mut pt, 0, &pt_want, 0);
//
// let noise_have: f64 = pt.data.std(0, log_base2k).log2();
// let noise_want: f64 = noise_grlwe_rlwe_product(
// module.n() as f64,
// log_base2k,
// 0.5,
// 0.5,
// 0f64,
// sigma * sigma,
// 0f64,
// log_k_grlwe,
// log_k_grlwe,
// );
//
// assert!(
// (noise_have - noise_want).abs() <= 0.2,
// "have: {} want: {}",
// noise_have,
// noise_want
// );
//
// pt_want.data.zero();
// });
// });
// }
//
// #[test]
// fn external_product() {
// let module: Module<FFT64> = Module::<FFT64>::new(2048);
// let log_base2k: usize = 12;
// let log_k_rgsw_rhs: usize = 60;
// let log_k_rgsw_lhs_in: usize = 45;
// let log_k_rgsw_lhs_out: usize = 45;
// let rows: usize = (log_k_rgsw_lhs_in + log_base2k - 1) / log_base2k;
// let rank: usize = 1;
//
// let sigma: f64 = 3.2;
//
// let mut ct_rgsw_rhs: GGSWCiphertext<Vec<u8>, FFT64> = GGSWCiphertext::new(&module, log_base2k, log_k_rgsw_rhs, rows, rank);
// let mut ct_rgsw_lhs_in: GGSWCiphertext<Vec<u8>, FFT64> =
// GGSWCiphertext::new(&module, log_base2k, log_k_rgsw_lhs_in, rows, rank);
// let mut ct_rgsw_lhs_out: GGSWCiphertext<Vec<u8>, FFT64> =
// GGSWCiphertext::new(&module, log_base2k, log_k_rgsw_lhs_out, rows, rank);
// let mut pt_rgsw_lhs: ScalarZnx<Vec<u8>> = module.new_scalar_znx(1);
// let mut pt_rgsw_rhs: ScalarZnx<Vec<u8>> = module.new_scalar_znx(1);
//
// let mut source_xs: Source = Source::new([0u8; 32]);
// let mut source_xe: Source = Source::new([0u8; 32]);
// let mut source_xa: Source = Source::new([0u8; 32]);
//
// Random input plaintext
// pt_rgsw_lhs.fill_ternary_prob(0, 0.5, &mut source_xs);
//
// let k: usize = 1;
//
// pt_rgsw_rhs.to_mut().raw_mut()[k] = 1; //X^{k}
//
// let mut scratch: ScratchOwned = ScratchOwned::new(
// GLWESwitchingKey::encrypt_sk_scratch_space(&module, rank, ct_rgsw_rhs.size())
// | GLWECiphertextFourier::decrypt_scratch_space(&module, ct_rgsw_lhs_out.size())
// | GGSWCiphertext::encrypt_sk_scratch_space(&module, rank, ct_rgsw_lhs_in.size())
// | GGSWCiphertext::external_product_scratch_space(
// &module,
// ct_rgsw_lhs_out.size(),
// ct_rgsw_lhs_in.size(),
// ct_rgsw_rhs.size(),
// ),
// );
//
// let mut sk: SecretKey<Vec<u8>> = SecretKey::new(&module, rank);
// sk.fill_ternary_prob(0.5, &mut source_xs);
//
// let mut sk_dft: SecretKeyFourier<Vec<u8>, FFT64> = SecretKeyFourier::new(&module, rank);
// sk_dft.dft(&module, &sk);
//
// ct_rgsw_rhs.encrypt_sk(
// &module,
// &pt_rgsw_rhs,
// &sk_dft,
// &mut source_xa,
// &mut source_xe,
// sigma,
// scratch.borrow(),
// );
//
// ct_rgsw_lhs_in.encrypt_sk(
// &module,
// &pt_rgsw_lhs,
// &sk_dft,
// &mut source_xa,
// &mut source_xe,
// sigma,
// scratch.borrow(),
// );
//
// ct_rgsw_lhs_out.external_product(&module, &ct_rgsw_lhs_in, &ct_rgsw_rhs, scratch.borrow());
//
// let mut ct_rlwe_dft: GLWECiphertextFourier<Vec<u8>, FFT64> =
// GLWECiphertextFourier::new(&module, log_base2k, log_k_rgsw_lhs_out, rank);
// let mut pt: GLWEPlaintext<Vec<u8>> = GLWEPlaintext::new(&module, log_base2k, log_k_rgsw_lhs_out);
// let mut pt_dft: VecZnxDft<Vec<u8>, FFT64> = module.new_vec_znx_dft(1, ct_rgsw_lhs_out.size());
// let mut pt_big: VecZnxBig<Vec<u8>, FFT64> = module.new_vec_znx_big(1, ct_rgsw_lhs_out.size());
// let mut pt_want: GLWEPlaintext<Vec<u8>> = GLWEPlaintext::new(&module, log_base2k, log_k_rgsw_lhs_out);
//
// module.vec_znx_rotate_inplace(k as i64, &mut pt_rgsw_lhs, 0);
//
// (0..ct_rgsw_lhs_out.rank()).for_each(|col_j| {
// (0..ct_rgsw_lhs_out.rows()).for_each(|row_i| {
// module.vec_znx_add_scalar_inplace(&mut pt_want, 0, row_i, &pt_rgsw_lhs, 0);
//
// if col_j == 1 {
// module.vec_znx_dft(&mut pt_dft, 0, &pt_want, 0);
// module.svp_apply_inplace(&mut pt_dft, 0, &sk_dft, 0);
// module.vec_znx_idft_tmp_a(&mut pt_big, 0, &mut pt_dft, 0);
// module.vec_znx_big_normalize(log_base2k, &mut pt_want, 0, &pt_big, 0, scratch.borrow());
// }
//
// ct_rgsw_lhs_out.get_row(&module, row_i, col_j, &mut ct_rlwe_dft);
// ct_rlwe_dft.decrypt(&module, &mut pt, &sk_dft, scratch.borrow());
//
// module.vec_znx_sub_ab_inplace(&mut pt, 0, &pt_want, 0);
//
// let noise_have: f64 = pt.data.std(0, log_base2k).log2();
//
// let var_gct_err_lhs: f64 = sigma * sigma;
// let var_gct_err_rhs: f64 = 0f64;
//
// let var_msg: f64 = 1f64 / module.n() as f64; // X^{k}
// let var_a0_err: f64 = sigma * sigma;
// let var_a1_err: f64 = 1f64 / 12f64;
//
// let noise_want: f64 = noise_rgsw_product(
// module.n() as f64,
// log_base2k,
// 0.5,
// var_msg,
// var_a0_err,
// var_a1_err,
// var_gct_err_lhs,
// var_gct_err_rhs,
// log_k_rgsw_lhs_in,
// log_k_rgsw_rhs,
// );
//
// assert!(
// (noise_have - noise_want).abs() <= 0.1,
// "have: {} want: {}",
// noise_have,
// noise_want
// );
//
// pt_want.data.zero();
// });
// });
// }
//
// #[test]
// fn external_product_inplace() {
// let module: Module<FFT64> = Module::<FFT64>::new(2048);
// let log_base2k: usize = 12;
// let log_k_rgsw_rhs: usize = 60;
// let log_k_rgsw_lhs: usize = 45;
// let rows: usize = (log_k_rgsw_lhs + log_base2k - 1) / log_base2k;
// let rank: usize = 1;
//
// let sigma: f64 = 3.2;
//
// let mut ct_rgsw_rhs: GGSWCiphertext<Vec<u8>, FFT64> = GGSWCiphertext::new(&module, log_base2k, log_k_rgsw_rhs, rows, rank);
// let mut ct_rgsw_lhs: GGSWCiphertext<Vec<u8>, FFT64> = GGSWCiphertext::new(&module, log_base2k, log_k_rgsw_lhs, rows, rank);
// let mut pt_rgsw_lhs: ScalarZnx<Vec<u8>> = module.new_scalar_znx(1);
// let mut pt_rgsw_rhs: ScalarZnx<Vec<u8>> = module.new_scalar_znx(1);
//
// let mut source_xs: Source = Source::new([0u8; 32]);
// let mut source_xe: Source = Source::new([0u8; 32]);
// let mut source_xa: Source = Source::new([0u8; 32]);
//
// Random input plaintext
// pt_rgsw_lhs.fill_ternary_prob(0, 0.5, &mut source_xs);
//
// let k: usize = 1;
//
// pt_rgsw_rhs.to_mut().raw_mut()[k] = 1; //X^{k}
//
// let mut scratch: ScratchOwned = ScratchOwned::new(
// GLWESwitchingKey::encrypt_sk_scratch_space(&module, rank, ct_rgsw_rhs.size())
// | GLWECiphertextFourier::decrypt_scratch_space(&module, ct_rgsw_lhs.size())
// | GGSWCiphertext::encrypt_sk_scratch_space(&module, rank, ct_rgsw_lhs.size())
// | GGSWCiphertext::external_product_inplace_scratch_space(&module, ct_rgsw_lhs.size(), ct_rgsw_rhs.size()),
// );
//
// let mut sk: SecretKey<Vec<u8>> = SecretKey::new(&module, rank);
// sk.fill_ternary_prob(0.5, &mut source_xs);
//
// let mut sk_dft: SecretKeyFourier<Vec<u8>, FFT64> = SecretKeyFourier::new(&module, rank);
// sk_dft.dft(&module, &sk);
//
// ct_rgsw_rhs.encrypt_sk(
// &module,
// &pt_rgsw_rhs,
// &sk_dft,
// &mut source_xa,
// &mut source_xe,
// sigma,
// scratch.borrow(),
// );
//
// ct_rgsw_lhs.encrypt_sk(
// &module,
// &pt_rgsw_lhs,
// &sk_dft,
// &mut source_xa,
// &mut source_xe,
// sigma,
// scratch.borrow(),
// );
//
// ct_rgsw_lhs.external_product_inplace(&module, &ct_rgsw_rhs, scratch.borrow());
//
// let mut ct_rlwe_dft: GLWECiphertextFourier<Vec<u8>, FFT64> =
// GLWECiphertextFourier::new(&module, log_base2k, log_k_rgsw_lhs, rank);
// let mut pt: GLWEPlaintext<Vec<u8>> = GLWEPlaintext::new(&module, log_base2k, log_k_rgsw_lhs);
// let mut pt_dft: VecZnxDft<Vec<u8>, FFT64> = module.new_vec_znx_dft(1, ct_rgsw_lhs.size());
// let mut pt_big: VecZnxBig<Vec<u8>, FFT64> = module.new_vec_znx_big(1, ct_rgsw_lhs.size());
// let mut pt_want: GLWEPlaintext<Vec<u8>> = GLWEPlaintext::new(&module, log_base2k, log_k_rgsw_lhs);
//
// module.vec_znx_rotate_inplace(k as i64, &mut pt_rgsw_lhs, 0);
//
// (0..ct_rgsw_lhs.rank()).for_each(|col_j| {
// (0..ct_rgsw_lhs.rows()).for_each(|row_i| {
// module.vec_znx_add_scalar_inplace(&mut pt_want, 0, row_i, &pt_rgsw_lhs, 0);
//
// if col_j == 1 {
// module.vec_znx_dft(&mut pt_dft, 0, &pt_want, 0);
// module.svp_apply_inplace(&mut pt_dft, 0, &sk_dft, 0);
// module.vec_znx_idft_tmp_a(&mut pt_big, 0, &mut pt_dft, 0);
// module.vec_znx_big_normalize(log_base2k, &mut pt_want, 0, &pt_big, 0, scratch.borrow());
// }
//
// ct_rgsw_lhs.get_row(&module, row_i, col_j, &mut ct_rlwe_dft);
// ct_rlwe_dft.decrypt(&module, &mut pt, &sk_dft, scratch.borrow());
//
// module.vec_znx_sub_ab_inplace(&mut pt, 0, &pt_want, 0);
//
// let noise_have: f64 = pt.data.std(0, log_base2k).log2();
//
// let var_gct_err_lhs: f64 = sigma * sigma;
// let var_gct_err_rhs: f64 = 0f64;
//
// let var_msg: f64 = 1f64 / module.n() as f64; // X^{k}
// let var_a0_err: f64 = sigma * sigma;
// let var_a1_err: f64 = 1f64 / 12f64;
//
// let noise_want: f64 = noise_rgsw_product(
// module.n() as f64,
// log_base2k,
// 0.5,
// var_msg,
// var_a0_err,
// var_a1_err,
// var_gct_err_lhs,
// var_gct_err_rhs,
// log_k_rgsw_lhs,
// log_k_rgsw_rhs,
// );
//
// assert!(
// (noise_have - noise_want).abs() <= 0.1,
// "have: {} want: {}",
// noise_have,
// noise_want
// );
//
// pt_want.data.zero();
// });
// });
// }
pub(crate) fn noise_ggsw_gglwe_product(
use base2k::{
FFT64, Module, ScalarZnx, ScalarZnxAlloc, ScalarZnxDftOps, ScratchOwned, Stats, VecZnxBig, VecZnxBigAlloc, VecZnxBigOps,
VecZnxDft, VecZnxDftAlloc, VecZnxDftOps, VecZnxOps, VecZnxToMut, ZnxViewMut, ZnxZero,
};
use sampling::source::Source;
use crate::{
elem::{GetRow, Infos},
ggsw_ciphertext::GGSWCiphertext,
glwe_ciphertext_fourier::GLWECiphertextFourier,
glwe_plaintext::GLWEPlaintext,
keys::{SecretKey, SecretKeyFourier},
keyswitch_key::GLWESwitchingKey,
};
#[test]
fn encrypt_sk() {
(1..4).for_each(|rank| {
println!("test encrypt_sk rank: {}", rank);
test_encrypt_sk(11, 8, 54, 3.2, rank);
});
}
#[test]
fn external_product() {
(1..4).for_each(|rank| {
println!("test external_product rank: {}", rank);
test_external_product(12, 12, 60, rank, 3.2);
});
}
#[test]
fn external_product_inplace() {
(1..4).for_each(|rank| {
println!("test external_product rank: {}", rank);
test_external_product_inplace(12, 15, 60, rank, 3.2);
});
}
fn test_encrypt_sk(log_n: usize, basek: usize, k_ggsw: usize, sigma: f64, rank: usize) {
let module: Module<FFT64> = Module::<FFT64>::new(1 << log_n);
let rows: usize = (k_ggsw + basek - 1) / basek;
let mut ct: GGSWCiphertext<Vec<u8>, FFT64> = GGSWCiphertext::new(&module, basek, k_ggsw, rows, rank);
let mut pt_have: GLWEPlaintext<Vec<u8>> = GLWEPlaintext::new(&module, basek, k_ggsw);
let mut pt_want: GLWEPlaintext<Vec<u8>> = GLWEPlaintext::new(&module, basek, k_ggsw);
let mut pt_scalar: ScalarZnx<Vec<u8>> = module.new_scalar_znx(1);
let mut source_xs: Source = Source::new([0u8; 32]);
let mut source_xe: Source = Source::new([0u8; 32]);
let mut source_xa: Source = Source::new([0u8; 32]);
pt_scalar.fill_ternary_hw(0, module.n(), &mut source_xs);
let mut scratch: ScratchOwned = ScratchOwned::new(
GGSWCiphertext::encrypt_sk_scratch_space(&module, rank, ct.size())
| GLWECiphertextFourier::decrypt_scratch_space(&module, ct.size()),
);
let mut sk: SecretKey<Vec<u8>> = SecretKey::new(&module, rank);
sk.fill_ternary_prob(0.5, &mut source_xs);
let mut sk_dft: SecretKeyFourier<Vec<u8>, FFT64> = SecretKeyFourier::new(&module, rank);
sk_dft.dft(&module, &sk);
ct.encrypt_sk(
&module,
&pt_scalar,
&sk_dft,
&mut source_xa,
&mut source_xe,
sigma,
scratch.borrow(),
);
let mut ct_glwe_fourier: GLWECiphertextFourier<Vec<u8>, FFT64> = GLWECiphertextFourier::new(&module, basek, k_ggsw, rank);
let mut pt_dft: VecZnxDft<Vec<u8>, FFT64> = module.new_vec_znx_dft(1, ct.size());
let mut pt_big: VecZnxBig<Vec<u8>, FFT64> = module.new_vec_znx_big(1, ct.size());
(0..ct.rank() + 1).for_each(|col_j| {
(0..ct.rows()).for_each(|row_i| {
module.vec_znx_add_scalar_inplace(&mut pt_want, 0, row_i, &pt_scalar, 0);
// mul with sk[col_j-1]
if col_j > 0 {
module.vec_znx_dft(&mut pt_dft, 0, &pt_want, 0);
module.svp_apply_inplace(&mut pt_dft, 0, &sk_dft, col_j - 1);
module.vec_znx_idft_tmp_a(&mut pt_big, 0, &mut pt_dft, 0);
module.vec_znx_big_normalize(basek, &mut pt_want, 0, &pt_big, 0, scratch.borrow());
}
ct.get_row(&module, row_i, col_j, &mut ct_glwe_fourier);
ct_glwe_fourier.decrypt(&module, &mut pt_have, &sk_dft, scratch.borrow());
module.vec_znx_sub_ab_inplace(&mut pt_have, 0, &pt_want, 0);
let std_pt: f64 = pt_have.data.std(0, basek) * (k_ggsw as f64).exp2();
assert!((sigma - std_pt).abs() <= 0.2, "{} {}", sigma, std_pt);
pt_want.data.zero();
});
});
}
fn test_external_product(log_n: usize, basek: usize, k_ggsw: usize, rank: usize, sigma: f64) {
let module: Module<FFT64> = Module::<FFT64>::new(1 << log_n);
let rows: usize = (k_ggsw + basek - 1) / basek;
let mut ct_ggsw_rhs: GGSWCiphertext<Vec<u8>, FFT64> = GGSWCiphertext::new(&module, basek, k_ggsw, rows, rank);
let mut ct_ggsw_lhs_in: GGSWCiphertext<Vec<u8>, FFT64> = GGSWCiphertext::new(&module, basek, k_ggsw, rows, rank);
let mut ct_ggsw_lhs_out: GGSWCiphertext<Vec<u8>, FFT64> = GGSWCiphertext::new(&module, basek, k_ggsw, rows, rank);
let mut pt_ggsw_lhs: ScalarZnx<Vec<u8>> = module.new_scalar_znx(1);
let mut pt_ggsw_rhs: ScalarZnx<Vec<u8>> = module.new_scalar_znx(1);
let mut source_xs: Source = Source::new([0u8; 32]);
let mut source_xe: Source = Source::new([0u8; 32]);
let mut source_xa: Source = Source::new([0u8; 32]);
pt_ggsw_lhs.fill_ternary_prob(0, 0.5, &mut source_xs);
let k: usize = 1;
pt_ggsw_rhs.to_mut().raw_mut()[k] = 1; //X^{k}
let mut scratch: ScratchOwned = ScratchOwned::new(
GLWESwitchingKey::encrypt_sk_scratch_space(&module, rank, ct_ggsw_rhs.size())
| GLWECiphertextFourier::decrypt_scratch_space(&module, ct_ggsw_lhs_out.size())
| GGSWCiphertext::encrypt_sk_scratch_space(&module, rank, ct_ggsw_lhs_in.size())
| GGSWCiphertext::external_product_scratch_space(
&module,
ct_ggsw_lhs_out.size(),
ct_ggsw_lhs_in.size(),
ct_ggsw_rhs.size(),
rank,
),
);
let mut sk: SecretKey<Vec<u8>> = SecretKey::new(&module, rank);
sk.fill_ternary_prob(0.5, &mut source_xs);
let mut sk_dft: SecretKeyFourier<Vec<u8>, FFT64> = SecretKeyFourier::new(&module, rank);
sk_dft.dft(&module, &sk);
ct_ggsw_rhs.encrypt_sk(
&module,
&pt_ggsw_rhs,
&sk_dft,
&mut source_xa,
&mut source_xe,
sigma,
scratch.borrow(),
);
ct_ggsw_lhs_in.encrypt_sk(
&module,
&pt_ggsw_lhs,
&sk_dft,
&mut source_xa,
&mut source_xe,
sigma,
scratch.borrow(),
);
ct_ggsw_lhs_out.external_product(&module, &ct_ggsw_lhs_in, &ct_ggsw_rhs, scratch.borrow());
let mut ct_glwe_fourier: GLWECiphertextFourier<Vec<u8>, FFT64> = GLWECiphertextFourier::new(&module, basek, k_ggsw, rank);
let mut pt: GLWEPlaintext<Vec<u8>> = GLWEPlaintext::new(&module, basek, k_ggsw);
let mut pt_dft: VecZnxDft<Vec<u8>, FFT64> = module.new_vec_znx_dft(1, ct_ggsw_lhs_out.size());
let mut pt_big: VecZnxBig<Vec<u8>, FFT64> = module.new_vec_znx_big(1, ct_ggsw_lhs_out.size());
let mut pt_want: GLWEPlaintext<Vec<u8>> = GLWEPlaintext::new(&module, basek, k_ggsw);
module.vec_znx_rotate_inplace(k as i64, &mut pt_ggsw_lhs, 0);
(0..ct_ggsw_lhs_out.rank() + 1).for_each(|col_j| {
(0..ct_ggsw_lhs_out.rows()).for_each(|row_i| {
module.vec_znx_add_scalar_inplace(&mut pt_want, 0, row_i, &pt_ggsw_lhs, 0);
if col_j > 0 {
module.vec_znx_dft(&mut pt_dft, 0, &pt_want, 0);
module.svp_apply_inplace(&mut pt_dft, 0, &sk_dft, col_j - 1);
module.vec_znx_idft_tmp_a(&mut pt_big, 0, &mut pt_dft, 0);
module.vec_znx_big_normalize(basek, &mut pt_want, 0, &pt_big, 0, scratch.borrow());
}
ct_ggsw_lhs_out.get_row(&module, row_i, col_j, &mut ct_glwe_fourier);
ct_glwe_fourier.decrypt(&module, &mut pt, &sk_dft, scratch.borrow());
module.vec_znx_sub_ab_inplace(&mut pt, 0, &pt_want, 0);
let noise_have: f64 = pt.data.std(0, basek).log2();
let var_gct_err_lhs: f64 = sigma * sigma;
let var_gct_err_rhs: f64 = 0f64;
let var_msg: f64 = 1f64 / module.n() as f64; // X^{k}
let var_a0_err: f64 = sigma * sigma;
let var_a1_err: f64 = 1f64 / 12f64;
let noise_want: f64 = noise_ggsw_product(
module.n() as f64,
basek,
0.5,
var_msg,
var_a0_err,
var_a1_err,
var_gct_err_lhs,
var_gct_err_rhs,
rank as f64,
k_ggsw,
k_ggsw,
);
assert!(
(noise_have - noise_want).abs() <= 0.1,
"have: {} want: {}",
noise_have,
noise_want
);
pt_want.data.zero();
});
});
}
fn test_external_product_inplace(log_n: usize, basek: usize, k_ggsw: usize, rank: usize, sigma: f64) {
let module: Module<FFT64> = Module::<FFT64>::new(1 << log_n);
let rows: usize = (k_ggsw + basek - 1) / basek;
let mut ct_ggsw_rhs: GGSWCiphertext<Vec<u8>, FFT64> = GGSWCiphertext::new(&module, basek, k_ggsw, rows, rank);
let mut ct_ggsw_lhs: GGSWCiphertext<Vec<u8>, FFT64> = GGSWCiphertext::new(&module, basek, k_ggsw, rows, rank);
let mut pt_ggsw_lhs: ScalarZnx<Vec<u8>> = module.new_scalar_znx(1);
let mut pt_ggsw_rhs: ScalarZnx<Vec<u8>> = module.new_scalar_znx(1);
let mut source_xs: Source = Source::new([0u8; 32]);
let mut source_xe: Source = Source::new([0u8; 32]);
let mut source_xa: Source = Source::new([0u8; 32]);
pt_ggsw_lhs.fill_ternary_prob(0, 0.5, &mut source_xs);
let k: usize = 1;
pt_ggsw_rhs.to_mut().raw_mut()[k] = 1; //X^{k}
let mut scratch: ScratchOwned = ScratchOwned::new(
GLWESwitchingKey::encrypt_sk_scratch_space(&module, rank, ct_ggsw_rhs.size())
| GLWECiphertextFourier::decrypt_scratch_space(&module, ct_ggsw_lhs.size())
| GGSWCiphertext::encrypt_sk_scratch_space(&module, rank, ct_ggsw_lhs.size())
| GGSWCiphertext::external_product_inplace_scratch_space(&module, ct_ggsw_lhs.size(), ct_ggsw_rhs.size(), rank),
);
let mut sk: SecretKey<Vec<u8>> = SecretKey::new(&module, rank);
sk.fill_ternary_prob(0.5, &mut source_xs);
let mut sk_dft: SecretKeyFourier<Vec<u8>, FFT64> = SecretKeyFourier::new(&module, rank);
sk_dft.dft(&module, &sk);
ct_ggsw_rhs.encrypt_sk(
&module,
&pt_ggsw_rhs,
&sk_dft,
&mut source_xa,
&mut source_xe,
sigma,
scratch.borrow(),
);
ct_ggsw_lhs.encrypt_sk(
&module,
&pt_ggsw_lhs,
&sk_dft,
&mut source_xa,
&mut source_xe,
sigma,
scratch.borrow(),
);
ct_ggsw_lhs.external_product_inplace(&module, &ct_ggsw_rhs, scratch.borrow());
let mut ct_glwe_fourier: GLWECiphertextFourier<Vec<u8>, FFT64> = GLWECiphertextFourier::new(&module, basek, k_ggsw, rank);
let mut pt: GLWEPlaintext<Vec<u8>> = GLWEPlaintext::new(&module, basek, k_ggsw);
let mut pt_dft: VecZnxDft<Vec<u8>, FFT64> = module.new_vec_znx_dft(1, ct_ggsw_lhs.size());
let mut pt_big: VecZnxBig<Vec<u8>, FFT64> = module.new_vec_znx_big(1, ct_ggsw_lhs.size());
let mut pt_want: GLWEPlaintext<Vec<u8>> = GLWEPlaintext::new(&module, basek, k_ggsw);
module.vec_znx_rotate_inplace(k as i64, &mut pt_ggsw_lhs, 0);
(0..ct_ggsw_lhs.rank() + 1).for_each(|col_j| {
(0..ct_ggsw_lhs.rows()).for_each(|row_i| {
module.vec_znx_add_scalar_inplace(&mut pt_want, 0, row_i, &pt_ggsw_lhs, 0);
if col_j > 0 {
module.vec_znx_dft(&mut pt_dft, 0, &pt_want, 0);
module.svp_apply_inplace(&mut pt_dft, 0, &sk_dft, col_j - 1);
module.vec_znx_idft_tmp_a(&mut pt_big, 0, &mut pt_dft, 0);
module.vec_znx_big_normalize(basek, &mut pt_want, 0, &pt_big, 0, scratch.borrow());
}
ct_ggsw_lhs.get_row(&module, row_i, col_j, &mut ct_glwe_fourier);
ct_glwe_fourier.decrypt(&module, &mut pt, &sk_dft, scratch.borrow());
module.vec_znx_sub_ab_inplace(&mut pt, 0, &pt_want, 0);
let noise_have: f64 = pt.data.std(0, basek).log2();
let var_gct_err_lhs: f64 = sigma * sigma;
let var_gct_err_rhs: f64 = 0f64;
let var_msg: f64 = 1f64 / module.n() as f64; // X^{k}
let var_a0_err: f64 = sigma * sigma;
let var_a1_err: f64 = 1f64 / 12f64;
let noise_want: f64 = noise_ggsw_product(
module.n() as f64,
basek,
0.5,
var_msg,
var_a0_err,
var_a1_err,
var_gct_err_lhs,
var_gct_err_rhs,
rank as f64,
k_ggsw,
k_ggsw,
);
assert!(
(noise_have - noise_want).abs() <= 0.1,
"have: {} want: {}",
noise_have,
noise_want
);
pt_want.data.zero();
});
});
}
pub(crate) fn noise_ggsw_product(
n: f64,
log_base2k: usize,
basek: usize,
var_xs: f64,
var_msg: f64,
var_a0_err: f64,
@@ -581,12 +351,12 @@ pub(crate) fn noise_ggsw_gglwe_product(
b_logq: usize,
) -> f64 {
let a_logq: usize = a_logq.min(b_logq);
let a_cols: usize = (a_logq + log_base2k - 1) / log_base2k;
let a_cols: usize = (a_logq + basek - 1) / basek;
let b_scale = 2.0f64.powi(b_logq as i32);
let a_scale: f64 = 2.0f64.powi((b_logq - a_logq) as i32);
let base: f64 = (1 << (log_base2k)) as f64;
let base: f64 = (1 << (basek)) as f64;
let var_base: f64 = base * base / 12f64;
// lhs = a_cols * n * (var_base * var_gct_err_lhs + var_e_a * var_msg * p^2)

View File

@@ -13,7 +13,7 @@ use crate::{
glwe_plaintext::GLWEPlaintext,
keys::{GLWEPublicKey, SecretKey, SecretKeyFourier},
keyswitch_key::GLWESwitchingKey,
test_fft64::{gglwe::noise_gglwe_product, ggsw::noise_ggsw_gglwe_product},
test_fft64::{gglwe::noise_gglwe_product, ggsw::noise_ggsw_product},
};
#[test]
@@ -498,7 +498,7 @@ fn test_external_product(log_n: usize, basek: usize, k_ggsw: usize, k_ct_in: usi
let var_a0_err: f64 = sigma * sigma;
let var_a1_err: f64 = 1f64 / 12f64;
let noise_want: f64 = noise_ggsw_gglwe_product(
let noise_want: f64 = noise_ggsw_product(
module.n() as f64,
basek,
0.5,
@@ -595,7 +595,7 @@ fn test_external_product_inplace(log_n: usize, basek: usize, k_ggsw: usize, k_ct
let var_a0_err: f64 = sigma * sigma;
let var_a1_err: f64 = 1f64 / 12f64;
let noise_want: f64 = noise_ggsw_gglwe_product(
let noise_want: f64 = noise_ggsw_product(
module.n() as f64,
basek,
0.5,

View File

@@ -6,7 +6,7 @@ use crate::{
glwe_plaintext::GLWEPlaintext,
keys::{SecretKey, SecretKeyFourier},
keyswitch_key::GLWESwitchingKey,
test_fft64::{gglwe::noise_gglwe_product, ggsw::noise_ggsw_gglwe_product},
test_fft64::{gglwe::noise_gglwe_product, ggsw::noise_ggsw_product},
};
use base2k::{FFT64, FillUniform, Module, ScalarZnx, ScalarZnxAlloc, ScratchOwned, Stats, VecZnxOps, VecZnxToMut, ZnxViewMut};
use sampling::source::Source;
@@ -322,7 +322,7 @@ fn test_external_product(log_n: usize, basek: usize, k_ggsw: usize, k_ct_in: usi
let var_a0_err: f64 = sigma * sigma;
let var_a1_err: f64 = 1f64 / 12f64;
let noise_want: f64 = noise_ggsw_gglwe_product(
let noise_want: f64 = noise_ggsw_product(
module.n() as f64,
basek,
0.5,
@@ -422,7 +422,7 @@ fn test_external_product_inplace(log_n: usize, basek: usize, k_ggsw: usize, k_ct
let var_a0_err: f64 = sigma * sigma;
let var_a1_err: f64 = 1f64 / 12f64;
let noise_want: f64 = noise_ggsw_gglwe_product(
let noise_want: f64 = noise_ggsw_product(
module.n() as f64,
basek,
0.5,

View File

@@ -1,3 +1,3 @@
pub(crate) fn derive_size(log_base2k: usize, log_k: usize) -> usize {
(log_k + log_base2k - 1) / log_base2k
pub(crate) fn derive_size(basek: usize, k: usize) -> usize {
(k + basek - 1) / basek
}

View File

@@ -1,218 +0,0 @@
use base2k::{
FFT64, Module, Scratch, VecZnx, VecZnxAlloc, VecZnxDft, VecZnxDftAlloc, VecZnxDftOps, VecZnxDftToMut, VecZnxDftToRef,
VecZnxToMut, VecZnxToRef, ZnxInfos, ZnxZero,
};
use crate::{
elem::{GetRow, Infos, SetRow},
glwe_ciphertext::GLWECiphertext,
glwe_ciphertext_fourier::GLWECiphertextFourier,
};
pub(crate) trait VecGLWEProductScratchSpace {
fn prod_with_glwe_scratch_space(
module: &Module<FFT64>,
res_size: usize,
lhs: usize,
rhs: usize,
rank_in: usize,
rank_out: usize,
) -> usize;
fn prod_with_glwe_inplace_scratch_space(module: &Module<FFT64>, res_size: usize, rhs: usize, rank: usize) -> usize {
Self::prod_with_glwe_scratch_space(module, res_size, res_size, rhs, rank, rank)
}
fn prod_with_glwe_fourier_scratch_space(
module: &Module<FFT64>,
res_size: usize,
lhs: usize,
rhs: usize,
rank_in: usize,
rank_out: usize,
) -> usize {
(Self::prod_with_glwe_scratch_space(module, res_size, lhs, rhs, rank_in, rank_out) | module.vec_znx_idft_tmp_bytes())
+ module.bytes_of_vec_znx(rank_in + 1, lhs)
+ module.bytes_of_vec_znx(rank_out + 1, res_size)
}
fn prod_with_glwe_fourier_inplace_scratch_space(module: &Module<FFT64>, res_size: usize, rhs: usize, rank: usize) -> usize {
(Self::prod_with_glwe_inplace_scratch_space(module, res_size, rhs, rank) | module.vec_znx_idft_tmp_bytes())
+ module.bytes_of_vec_znx(rank + 1, res_size)
}
fn prod_with_vec_glwe_scratch_space(
module: &Module<FFT64>,
res_size: usize,
lhs: usize,
rhs: usize,
rank_in: usize,
rank_out: usize,
) -> usize {
Self::prod_with_glwe_fourier_scratch_space(module, res_size, lhs, rhs, rank_in, rank_out)
+ module.bytes_of_vec_znx_dft(rank_in + 1, lhs)
+ module.bytes_of_vec_znx_dft(rank_out + 1, res_size)
}
fn prod_with_vec_glwe_inplace_scratch_space(module: &Module<FFT64>, res_size: usize, rhs: usize, rank: usize) -> usize {
Self::prod_with_glwe_fourier_inplace_scratch_space(module, res_size, rhs, rank)
+ module.bytes_of_vec_znx_dft(rank + 1, res_size)
}
}
pub(crate) trait VecGLWEProduct: Infos {
fn prod_with_glwe<MUT, REF>(
&self,
module: &Module<FFT64>,
res: &mut GLWECiphertext<MUT>,
a: &GLWECiphertext<REF>,
scratch: &mut Scratch,
) where
VecZnx<MUT>: VecZnxToMut,
VecZnx<REF>: VecZnxToRef;
fn prod_with_glwe_inplace<MUT>(&self, module: &Module<FFT64>, res: &mut GLWECiphertext<MUT>, scratch: &mut Scratch)
where
VecZnx<MUT>: VecZnxToMut + VecZnxToRef,
{
unsafe {
let res_ptr: *mut GLWECiphertext<MUT> = res as *mut GLWECiphertext<MUT>; // This is ok because [Self::mul_rlwe] only updates res at the end.
self.prod_with_glwe(&module, &mut *res_ptr, &*res_ptr, scratch);
}
}
fn prod_with_glwe_fourier<MUT, REF>(
&self,
module: &Module<FFT64>,
res: &mut GLWECiphertextFourier<MUT, FFT64>,
a: &GLWECiphertextFourier<REF, FFT64>,
scratch: &mut Scratch,
) where
VecZnxDft<MUT, FFT64>: VecZnxDftToMut<FFT64> + VecZnxDftToRef<FFT64> + ZnxInfos,
VecZnxDft<REF, FFT64>: VecZnxDftToRef<FFT64> + ZnxInfos,
{
let log_base2k: usize = self.basek();
#[cfg(debug_assertions)]
{
assert_eq!(res.basek(), log_base2k);
assert_eq!(self.n(), module.n());
assert_eq!(res.n(), module.n());
}
let (a_data, scratch_1) = scratch.tmp_vec_znx(module, a.rank() + 1, a.size());
let mut a_idft: GLWECiphertext<&mut [u8]> = GLWECiphertext::<&mut [u8]> {
data: a_data,
basek: a.basek(),
k: a.k(),
};
a.idft(module, &mut a_idft, scratch_1);
let (res_data, scratch_2) = scratch_1.tmp_vec_znx(module, res.rank() + 1, res.size());
let mut res_idft: GLWECiphertext<&mut [u8]> = GLWECiphertext::<&mut [u8]> {
data: res_data,
basek: res.basek(),
k: res.k(),
};
self.prod_with_glwe(module, &mut res_idft, &a_idft, scratch_2);
res_idft.dft(module, res);
}
fn prod_with_glwe_fourier_inplace<MUT>(
&self,
module: &Module<FFT64>,
res: &mut GLWECiphertextFourier<MUT, FFT64>,
scratch: &mut Scratch,
) where
VecZnxDft<MUT, FFT64>: VecZnxDftToRef<FFT64> + VecZnxDftToMut<FFT64>,
{
let log_base2k: usize = self.basek();
#[cfg(debug_assertions)]
{
assert_eq!(res.basek(), log_base2k);
assert_eq!(self.n(), module.n());
assert_eq!(res.n(), module.n());
}
let (res_data, scratch_1) = scratch.tmp_vec_znx(module, res.rank() + 1, res.size());
let mut res_idft: GLWECiphertext<&mut [u8]> = GLWECiphertext::<&mut [u8]> {
data: res_data,
basek: res.basek(),
k: res.k(),
};
res.idft(module, &mut res_idft, scratch_1);
self.prod_with_glwe_inplace(module, &mut res_idft, scratch_1);
res_idft.dft(module, res);
}
fn prod_with_vec_glwe<RES, LHS>(&self, module: &Module<FFT64>, res: &mut RES, a: &LHS, scratch: &mut Scratch)
where
LHS: GetRow<FFT64> + Infos,
RES: SetRow<FFT64> + Infos,
{
let (tmp_row_data, scratch1) = scratch.tmp_vec_znx_dft(module, a.cols(), a.size());
let mut tmp_a_row: GLWECiphertextFourier<&mut [u8], FFT64> = GLWECiphertextFourier::<&mut [u8], FFT64> {
data: tmp_row_data,
basek: a.basek(),
k: a.k(),
};
let (tmp_res_data, scratch2) = scratch1.tmp_vec_znx_dft(module, res.cols(), res.size());
let mut tmp_res_row: GLWECiphertextFourier<&mut [u8], FFT64> = GLWECiphertextFourier::<&mut [u8], FFT64> {
data: tmp_res_data,
basek: res.basek(),
k: res.k(),
};
let min_rows: usize = res.rows().min(a.rows());
(0..res.rows()).for_each(|row_i| {
(0..res.cols()).for_each(|col_j| {
a.get_row(module, row_i, col_j, &mut tmp_a_row);
self.prod_with_glwe_fourier(module, &mut tmp_res_row, &tmp_a_row, scratch2);
res.set_row(module, row_i, col_j, &tmp_res_row);
});
});
tmp_res_row.data.zero();
(min_rows..res.rows()).for_each(|row_i| {
(0..self.cols()).for_each(|col_j| {
res.set_row(module, row_i, col_j, &tmp_res_row);
});
});
}
fn prod_with_vec_glwe_inplace<RES>(&self, module: &Module<FFT64>, res: &mut RES, scratch: &mut Scratch)
where
RES: GetRow<FFT64> + SetRow<FFT64> + Infos,
{
let (tmp_row_data, scratch1) = scratch.tmp_vec_znx_dft(module, res.cols(), res.size());
let mut tmp_row: GLWECiphertextFourier<&mut [u8], FFT64> = GLWECiphertextFourier::<&mut [u8], FFT64> {
data: tmp_row_data,
basek: res.basek(),
k: res.k(),
};
(0..res.rows()).for_each(|row_i| {
(0..res.cols()).for_each(|col_j| {
res.get_row(module, row_i, col_j, &mut tmp_row);
self.prod_with_glwe_fourier_inplace(module, &mut tmp_row, scratch1);
res.set_row(module, row_i, col_j, &tmp_row);
});
});
}
}