mirror of
https://github.com/arnaucube/poulpy.git
synced 2026-02-10 13:16:44 +01:00
refactor of key-switching & external product
This commit is contained in:
@@ -42,8 +42,13 @@ pub trait VecZnxDftOps<B: Backend> {
|
|||||||
/// a new [VecZnxDft] through [VecZnxDft::from_bytes].
|
/// a new [VecZnxDft] through [VecZnxDft::from_bytes].
|
||||||
fn vec_znx_idft_tmp_bytes(&self) -> usize;
|
fn vec_znx_idft_tmp_bytes(&self) -> usize;
|
||||||
|
|
||||||
|
fn vec_znx_dft_copy<R, A>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize)
|
||||||
|
where
|
||||||
|
R: VecZnxDftToMut<B>,
|
||||||
|
A: VecZnxDftToRef<B>;
|
||||||
|
|
||||||
/// b <- IDFT(a), uses a as scratch space.
|
/// b <- IDFT(a), uses a as scratch space.
|
||||||
fn vec_znx_idft_tmp_a<R, A>(&self, res: &mut R, res_col: usize, a: &mut A, a_cols: usize)
|
fn vec_znx_idft_tmp_a<R, A>(&self, res: &mut R, res_col: usize, a: &mut A, a_col: usize)
|
||||||
where
|
where
|
||||||
R: VecZnxBigToMut<B>,
|
R: VecZnxBigToMut<B>,
|
||||||
A: VecZnxDftToMut<B>;
|
A: VecZnxDftToMut<B>;
|
||||||
@@ -79,13 +84,33 @@ impl<B: Backend> VecZnxDftAlloc<B> for Module<B> {
|
|||||||
}
|
}
|
||||||
|
|
||||||
impl VecZnxDftOps<FFT64> for Module<FFT64> {
|
impl VecZnxDftOps<FFT64> for Module<FFT64> {
|
||||||
|
fn vec_znx_dft_copy<R, A>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize)
|
||||||
|
where
|
||||||
|
R: VecZnxDftToMut<FFT64>,
|
||||||
|
A: VecZnxDftToRef<FFT64>,
|
||||||
|
{
|
||||||
|
let mut res_mut: VecZnxDft<&mut [u8], FFT64> = res.to_mut();
|
||||||
|
let a_ref: VecZnxDft<&[u8], FFT64> = a.to_ref();
|
||||||
|
|
||||||
|
let min_size: usize = min(res_mut.size(), a_ref.size());
|
||||||
|
|
||||||
|
(0..min_size).for_each(|j| {
|
||||||
|
res_mut
|
||||||
|
.at_mut(res_col, j)
|
||||||
|
.copy_from_slice(a_ref.at(a_col, j));
|
||||||
|
});
|
||||||
|
(min_size..res_mut.size()).for_each(|j| {
|
||||||
|
res_mut.zero_at(res_col, j);
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
fn vec_znx_idft_tmp_a<R, A>(&self, res: &mut R, res_col: usize, a: &mut A, a_col: usize)
|
fn vec_znx_idft_tmp_a<R, A>(&self, res: &mut R, res_col: usize, a: &mut A, a_col: usize)
|
||||||
where
|
where
|
||||||
R: VecZnxBigToMut<FFT64>,
|
R: VecZnxBigToMut<FFT64>,
|
||||||
A: VecZnxDftToMut<FFT64>,
|
A: VecZnxDftToMut<FFT64>,
|
||||||
{
|
{
|
||||||
let mut res_mut = res.to_mut();
|
let mut res_mut: VecZnxBig<&mut [u8], FFT64> = res.to_mut();
|
||||||
let mut a_mut = a.to_mut();
|
let mut a_mut: VecZnxDft<&mut [u8], FFT64> = a.to_mut();
|
||||||
|
|
||||||
let min_size: usize = min(res_mut.size(), a_mut.size());
|
let min_size: usize = min(res_mut.size(), a_mut.size());
|
||||||
|
|
||||||
@@ -136,14 +161,14 @@ impl VecZnxDftOps<FFT64> for Module<FFT64> {
|
|||||||
/// b <- DFT(a)
|
/// b <- DFT(a)
|
||||||
///
|
///
|
||||||
/// # Panics
|
/// # Panics
|
||||||
/// If b.cols < a_cols
|
/// If b.cols < a_col
|
||||||
fn vec_znx_dft<R, A>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize)
|
fn vec_znx_dft<R, A>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize)
|
||||||
where
|
where
|
||||||
R: VecZnxDftToMut<FFT64>,
|
R: VecZnxDftToMut<FFT64>,
|
||||||
A: VecZnxToRef,
|
A: VecZnxToRef,
|
||||||
{
|
{
|
||||||
let mut res_mut = res.to_mut();
|
let mut res_mut: VecZnxDft<&mut [u8], FFT64> = res.to_mut();
|
||||||
let a_ref = a.to_ref();
|
let a_ref: crate::VecZnx<&[u8]> = a.to_ref();
|
||||||
|
|
||||||
let min_size: usize = min(res_mut.size(), a_ref.size());
|
let min_size: usize = min(res_mut.size(), a_ref.size());
|
||||||
|
|
||||||
@@ -170,8 +195,8 @@ impl VecZnxDftOps<FFT64> for Module<FFT64> {
|
|||||||
R: VecZnxBigToMut<FFT64>,
|
R: VecZnxBigToMut<FFT64>,
|
||||||
A: VecZnxDftToRef<FFT64>,
|
A: VecZnxDftToRef<FFT64>,
|
||||||
{
|
{
|
||||||
let mut res_mut = res.to_mut();
|
let mut res_mut: VecZnxBig<&mut [u8], FFT64> = res.to_mut();
|
||||||
let a_ref = a.to_ref();
|
let a_ref: VecZnxDft<&[u8], FFT64> = a.to_ref();
|
||||||
|
|
||||||
let (tmp_bytes, _) = scratch.tmp_slice(self.vec_znx_idft_tmp_bytes());
|
let (tmp_bytes, _) = scratch.tmp_slice(self.vec_znx_idft_tmp_bytes());
|
||||||
|
|
||||||
|
|||||||
@@ -1,8 +1,7 @@
|
|||||||
use base2k::{
|
use base2k::{
|
||||||
Backend, FFT64, MatZnxDft, MatZnxDftAlloc, MatZnxDftOps, MatZnxDftScratch, MatZnxDftToMut, MatZnxDftToRef, Module, ScalarZnx,
|
Backend, FFT64, MatZnxDft, MatZnxDftAlloc, MatZnxDftOps, MatZnxDftToMut, MatZnxDftToRef, Module, ScalarZnx, ScalarZnxDft,
|
||||||
ScalarZnxDft, ScalarZnxDftToRef, ScalarZnxToRef, Scratch, VecZnx, VecZnxAlloc, VecZnxBig, VecZnxBigOps, VecZnxBigScratch,
|
ScalarZnxDftToRef, ScalarZnxToRef, Scratch, VecZnxAlloc, VecZnxDft, VecZnxDftAlloc, VecZnxDftToMut, VecZnxDftToRef,
|
||||||
VecZnxDft, VecZnxDftAlloc, VecZnxDftOps, VecZnxDftToMut, VecZnxDftToRef, VecZnxOps, VecZnxToMut, VecZnxToRef, ZnxInfos,
|
VecZnxOps, ZnxInfos, ZnxZero,
|
||||||
ZnxZero,
|
|
||||||
};
|
};
|
||||||
use sampling::source::Source;
|
use sampling::source::Source;
|
||||||
|
|
||||||
@@ -13,7 +12,6 @@ use crate::{
|
|||||||
glwe_plaintext::GLWEPlaintext,
|
glwe_plaintext::GLWEPlaintext,
|
||||||
keys::SecretKeyFourier,
|
keys::SecretKeyFourier,
|
||||||
utils::derive_size,
|
utils::derive_size,
|
||||||
vec_glwe_product::{VecGLWEProduct, VecGLWEProductScratchSpace},
|
|
||||||
};
|
};
|
||||||
|
|
||||||
pub struct GGLWECiphertext<C, B: Backend> {
|
pub struct GGLWECiphertext<C, B: Backend> {
|
||||||
@@ -212,81 +210,3 @@ where
|
|||||||
module.vmp_prepare_row(self, row_i, col_j, a);
|
module.vmp_prepare_row(self, row_i, col_j, a);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl VecGLWEProductScratchSpace for GGLWECiphertext<Vec<u8>, FFT64> {
|
|
||||||
fn prod_with_glwe_scratch_space(
|
|
||||||
module: &Module<FFT64>,
|
|
||||||
res_size: usize,
|
|
||||||
a_size: usize,
|
|
||||||
grlwe_size: usize,
|
|
||||||
rank_in: usize,
|
|
||||||
rank_out: usize,
|
|
||||||
) -> usize {
|
|
||||||
module.bytes_of_vec_znx_dft(rank_out + 1, grlwe_size)
|
|
||||||
+ (module.vec_znx_big_normalize_tmp_bytes()
|
|
||||||
| (module.vmp_apply_tmp_bytes(res_size, a_size, a_size, rank_in, rank_out + 1, grlwe_size)
|
|
||||||
+ module.bytes_of_vec_znx_dft(rank_in, a_size)))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl<C> VecGLWEProduct for GGLWECiphertext<C, FFT64>
|
|
||||||
where
|
|
||||||
MatZnxDft<C, FFT64>: MatZnxDftToRef<FFT64> + ZnxInfos,
|
|
||||||
{
|
|
||||||
fn prod_with_glwe<R, A>(
|
|
||||||
&self,
|
|
||||||
module: &Module<FFT64>,
|
|
||||||
res: &mut GLWECiphertext<R>,
|
|
||||||
a: &GLWECiphertext<A>,
|
|
||||||
scratch: &mut Scratch,
|
|
||||||
) where
|
|
||||||
MatZnxDft<C, FFT64>: MatZnxDftToRef<FFT64>,
|
|
||||||
VecZnx<R>: VecZnxToMut,
|
|
||||||
VecZnx<A>: VecZnxToRef,
|
|
||||||
{
|
|
||||||
let basek: usize = self.basek();
|
|
||||||
|
|
||||||
#[cfg(debug_assertions)]
|
|
||||||
{
|
|
||||||
assert_eq!(a.rank(), self.rank_in());
|
|
||||||
assert_eq!(res.rank(), self.rank_out());
|
|
||||||
assert_eq!(res.basek(), basek);
|
|
||||||
assert_eq!(a.basek(), basek);
|
|
||||||
assert_eq!(self.n(), module.n());
|
|
||||||
assert_eq!(res.n(), module.n());
|
|
||||||
assert_eq!(a.n(), module.n());
|
|
||||||
assert!(
|
|
||||||
scratch.available()
|
|
||||||
>= GGLWECiphertext::prod_with_glwe_scratch_space(
|
|
||||||
module,
|
|
||||||
res.size(),
|
|
||||||
a.size(),
|
|
||||||
self.size(),
|
|
||||||
self.rank_in(),
|
|
||||||
self.rank_out()
|
|
||||||
)
|
|
||||||
);
|
|
||||||
}
|
|
||||||
|
|
||||||
let cols_in: usize = self.rank_in();
|
|
||||||
let cols_out: usize = self.rank_out() + 1;
|
|
||||||
|
|
||||||
let (mut res_dft, scratch1) = scratch.tmp_vec_znx_dft(module, cols_out, self.size()); // Todo optimise
|
|
||||||
|
|
||||||
{
|
|
||||||
let (mut ai_dft, scratch2) = scratch1.tmp_vec_znx_dft(module, cols_in, a.size());
|
|
||||||
(0..cols_in).for_each(|col_i| {
|
|
||||||
module.vec_znx_dft(&mut ai_dft, col_i, a, col_i + 1);
|
|
||||||
});
|
|
||||||
module.vmp_apply(&mut res_dft, &ai_dft, self, scratch2);
|
|
||||||
}
|
|
||||||
|
|
||||||
let mut res_big: VecZnxBig<&mut [u8], FFT64> = module.vec_znx_idft_consume(res_dft);
|
|
||||||
|
|
||||||
module.vec_znx_big_add_small_inplace(&mut res_big, 0, a, 0);
|
|
||||||
|
|
||||||
(0..cols_out).for_each(|i| {
|
|
||||||
module.vec_znx_big_normalize(basek, res, i, &res_big, i, scratch1);
|
|
||||||
});
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|||||||
@@ -1,35 +1,31 @@
|
|||||||
use base2k::{
|
use base2k::{
|
||||||
Backend, FFT64, MatZnxDft, MatZnxDftAlloc, MatZnxDftOps, MatZnxDftScratch, MatZnxDftToMut, MatZnxDftToRef, Module, ScalarZnx,
|
Backend, FFT64, MatZnxDft, MatZnxDftAlloc, MatZnxDftOps, MatZnxDftToMut, MatZnxDftToRef, Module, ScalarZnx, ScalarZnxDft,
|
||||||
ScalarZnxDft, ScalarZnxDftToRef, ScalarZnxToRef, Scratch, VecZnx, VecZnxAlloc, VecZnxBig, VecZnxBigOps, VecZnxBigScratch,
|
ScalarZnxDftToRef, ScalarZnxToRef, Scratch, VecZnxAlloc, VecZnxDft, VecZnxDftAlloc, VecZnxDftOps, VecZnxDftToMut,
|
||||||
VecZnxDft, VecZnxDftAlloc, VecZnxDftOps, VecZnxDftToMut, VecZnxDftToRef, VecZnxOps, VecZnxToMut, VecZnxToRef, ZnxInfos,
|
VecZnxDftToRef, VecZnxOps, ZnxInfos, ZnxZero,
|
||||||
ZnxZero,
|
|
||||||
};
|
};
|
||||||
use sampling::source::Source;
|
use sampling::source::Source;
|
||||||
|
|
||||||
use crate::{
|
use crate::{
|
||||||
elem::{GetRow, Infos, SetRow},
|
elem::{GetRow, Infos, SetRow},
|
||||||
gglwe_ciphertext::GGLWECiphertext,
|
|
||||||
glwe_ciphertext::GLWECiphertext,
|
glwe_ciphertext::GLWECiphertext,
|
||||||
glwe_ciphertext_fourier::GLWECiphertextFourier,
|
glwe_ciphertext_fourier::GLWECiphertextFourier,
|
||||||
glwe_plaintext::GLWEPlaintext,
|
glwe_plaintext::GLWEPlaintext,
|
||||||
keys::SecretKeyFourier,
|
keys::SecretKeyFourier,
|
||||||
keyswitch_key::GLWESwitchingKey,
|
|
||||||
utils::derive_size,
|
utils::derive_size,
|
||||||
vec_glwe_product::{VecGLWEProduct, VecGLWEProductScratchSpace},
|
|
||||||
};
|
};
|
||||||
|
|
||||||
pub struct GGSWCiphertext<C, B: Backend> {
|
pub struct GGSWCiphertext<C, B: Backend> {
|
||||||
pub data: MatZnxDft<C, B>,
|
pub data: MatZnxDft<C, B>,
|
||||||
pub log_base2k: usize,
|
pub basek: usize,
|
||||||
pub log_k: usize,
|
pub k: usize,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<B: Backend> GGSWCiphertext<Vec<u8>, B> {
|
impl<B: Backend> GGSWCiphertext<Vec<u8>, B> {
|
||||||
pub fn new(module: &Module<B>, log_base2k: usize, log_k: usize, rows: usize, rank: usize) -> Self {
|
pub fn new(module: &Module<B>, basek: usize, k: usize, rows: usize, rank: usize) -> Self {
|
||||||
Self {
|
Self {
|
||||||
data: module.new_mat_znx_dft(rows, rank + 1, rank + 1, derive_size(log_base2k, log_k)),
|
data: module.new_mat_znx_dft(rows, rank + 1, rank + 1, derive_size(basek, k)),
|
||||||
log_base2k: log_base2k,
|
basek: basek,
|
||||||
log_k: log_k,
|
k: k,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -42,11 +38,11 @@ impl<T, B: Backend> Infos for GGSWCiphertext<T, B> {
|
|||||||
}
|
}
|
||||||
|
|
||||||
fn basek(&self) -> usize {
|
fn basek(&self) -> usize {
|
||||||
self.log_base2k
|
self.basek
|
||||||
}
|
}
|
||||||
|
|
||||||
fn k(&self) -> usize {
|
fn k(&self) -> usize {
|
||||||
self.log_k
|
self.k
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -82,35 +78,28 @@ impl GGSWCiphertext<Vec<u8>, FFT64> {
|
|||||||
+ module.bytes_of_vec_znx_dft(rank + 1, size)
|
+ module.bytes_of_vec_znx_dft(rank + 1, size)
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn keyswitch_scratch_space(
|
pub fn external_product_scratch_space(
|
||||||
module: &Module<FFT64>,
|
module: &Module<FFT64>,
|
||||||
res_size: usize,
|
out_size: usize,
|
||||||
lhs: usize,
|
in_size: usize,
|
||||||
rhs: usize,
|
ggsw_size: usize,
|
||||||
rank_in: usize,
|
rank: usize,
|
||||||
rank_out: usize,
|
|
||||||
) -> usize {
|
) -> usize {
|
||||||
<GGLWECiphertext<Vec<u8>, FFT64> as VecGLWEProductScratchSpace>::prod_with_vec_glwe_scratch_space(
|
let tmp_in: usize = module.bytes_of_vec_znx_dft(rank + 1, in_size);
|
||||||
module, res_size, lhs, rhs, rank_in, rank_out,
|
let tmp_out: usize = module.bytes_of_vec_znx_dft(rank + 1, out_size);
|
||||||
)
|
let ggsw: usize = GLWECiphertextFourier::external_product_scratch_space(module, out_size, in_size, ggsw_size, rank);
|
||||||
|
tmp_in + tmp_out + ggsw
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn keyswitch_inplace_scratch_space(module: &Module<FFT64>, res_size: usize, rhs: usize, rank: usize) -> usize {
|
pub fn external_product_inplace_scratch_space(
|
||||||
<GGLWECiphertext<Vec<u8>, FFT64> as VecGLWEProductScratchSpace>::prod_with_vec_glwe_inplace_scratch_space(
|
module: &Module<FFT64>,
|
||||||
module, res_size, rhs, rank,
|
out_size: usize,
|
||||||
)
|
ggsw_size: usize,
|
||||||
}
|
rank: usize,
|
||||||
|
) -> usize {
|
||||||
pub fn external_product_scratch_space(module: &Module<FFT64>, res_size: usize, lhs: usize, rhs: usize, rank: usize) -> usize {
|
let tmp: usize = module.bytes_of_vec_znx_dft(rank + 1, out_size);
|
||||||
<GGSWCiphertext<Vec<u8>, FFT64> as VecGLWEProductScratchSpace>::prod_with_vec_glwe_scratch_space(
|
let ggsw: usize = GLWECiphertextFourier::external_product_inplace_scratch_space(module, out_size, ggsw_size, rank);
|
||||||
module, res_size, lhs, rhs, rank, rank,
|
tmp + ggsw
|
||||||
)
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn external_product_inplace_scratch_space(module: &Module<FFT64>, res_size: usize, rhs: usize, rank: usize) -> usize {
|
|
||||||
<GGSWCiphertext<Vec<u8>, FFT64> as VecGLWEProductScratchSpace>::prod_with_glwe_inplace_scratch_space(
|
|
||||||
module, res_size, rhs, rank,
|
|
||||||
)
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -140,7 +129,7 @@ where
|
|||||||
}
|
}
|
||||||
|
|
||||||
let size: usize = self.size();
|
let size: usize = self.size();
|
||||||
let log_base2k: usize = self.basek();
|
let basek: usize = self.basek();
|
||||||
let k: usize = self.k();
|
let k: usize = self.k();
|
||||||
let cols: usize = self.rank() + 1;
|
let cols: usize = self.rank() + 1;
|
||||||
|
|
||||||
@@ -149,20 +138,20 @@ where
|
|||||||
|
|
||||||
let mut vec_znx_pt: GLWEPlaintext<&mut [u8]> = GLWEPlaintext {
|
let mut vec_znx_pt: GLWEPlaintext<&mut [u8]> = GLWEPlaintext {
|
||||||
data: tmp_znx_pt,
|
data: tmp_znx_pt,
|
||||||
basek: log_base2k,
|
basek: basek,
|
||||||
k: k,
|
k: k,
|
||||||
};
|
};
|
||||||
|
|
||||||
let mut vec_znx_ct: GLWECiphertext<&mut [u8]> = GLWECiphertext {
|
let mut vec_znx_ct: GLWECiphertext<&mut [u8]> = GLWECiphertext {
|
||||||
data: tmp_znx_ct,
|
data: tmp_znx_ct,
|
||||||
basek: log_base2k,
|
basek: basek,
|
||||||
k,
|
k,
|
||||||
};
|
};
|
||||||
|
|
||||||
(0..self.rows()).for_each(|row_j| {
|
(0..self.rows()).for_each(|row_j| {
|
||||||
// Adds the scalar_znx_pt to the i-th limb of the vec_znx_pt
|
// Adds the scalar_znx_pt to the i-th limb of the vec_znx_pt
|
||||||
module.vec_znx_add_scalar_inplace(&mut vec_znx_pt, 0, row_j, pt, 0);
|
module.vec_znx_add_scalar_inplace(&mut vec_znx_pt, 0, row_j, pt, 0);
|
||||||
module.vec_znx_normalize_inplace(log_base2k, &mut vec_znx_pt, 0, scrach_2);
|
module.vec_znx_normalize_inplace(basek, &mut vec_znx_pt, 0, scrach_2);
|
||||||
|
|
||||||
(0..cols).for_each(|col_i| {
|
(0..cols).for_each(|col_i| {
|
||||||
// rlwe encrypt of vec_znx_pt into vec_znx_ct
|
// rlwe encrypt of vec_znx_pt into vec_znx_ct
|
||||||
@@ -193,30 +182,6 @@ where
|
|||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn keyswitch<DataLhs, DataRhs>(
|
|
||||||
&mut self,
|
|
||||||
module: &Module<FFT64>,
|
|
||||||
lhs: &GGSWCiphertext<DataLhs, FFT64>,
|
|
||||||
rhs: &GLWESwitchingKey<DataRhs, FFT64>,
|
|
||||||
scratch: &mut Scratch,
|
|
||||||
) where
|
|
||||||
MatZnxDft<DataLhs, FFT64>: MatZnxDftToRef<FFT64>,
|
|
||||||
MatZnxDft<DataRhs, FFT64>: MatZnxDftToRef<FFT64>,
|
|
||||||
{
|
|
||||||
rhs.0.prod_with_vec_glwe(module, self, lhs, scratch);
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn keyswitch_inplace<DataRhs>(
|
|
||||||
&mut self,
|
|
||||||
module: &Module<FFT64>,
|
|
||||||
rhs: &GLWESwitchingKey<DataRhs, FFT64>,
|
|
||||||
scratch: &mut Scratch,
|
|
||||||
) where
|
|
||||||
MatZnxDft<DataRhs, FFT64>: MatZnxDftToRef<FFT64>,
|
|
||||||
{
|
|
||||||
rhs.0.prod_with_vec_glwe_inplace(module, self, scratch);
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn external_product<DataLhs, DataRhs>(
|
pub fn external_product<DataLhs, DataRhs>(
|
||||||
&mut self,
|
&mut self,
|
||||||
module: &Module<FFT64>,
|
module: &Module<FFT64>,
|
||||||
@@ -227,7 +192,55 @@ where
|
|||||||
MatZnxDft<DataLhs, FFT64>: MatZnxDftToRef<FFT64>,
|
MatZnxDft<DataLhs, FFT64>: MatZnxDftToRef<FFT64>,
|
||||||
MatZnxDft<DataRhs, FFT64>: MatZnxDftToRef<FFT64>,
|
MatZnxDft<DataRhs, FFT64>: MatZnxDftToRef<FFT64>,
|
||||||
{
|
{
|
||||||
rhs.prod_with_vec_glwe(module, self, lhs, scratch);
|
#[cfg(debug_assertions)]
|
||||||
|
{
|
||||||
|
assert_eq!(
|
||||||
|
self.rank(),
|
||||||
|
lhs.rank(),
|
||||||
|
"ggsw_out rank: {} != ggsw_in rank: {}",
|
||||||
|
self.rank(),
|
||||||
|
lhs.rank()
|
||||||
|
);
|
||||||
|
assert_eq!(
|
||||||
|
self.rank(),
|
||||||
|
rhs.rank(),
|
||||||
|
"ggsw_in rank: {} != ggsw_apply rank: {}",
|
||||||
|
self.rank(),
|
||||||
|
rhs.rank()
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
let (tmp_in_data, scratch1) = scratch.tmp_vec_znx_dft(module, lhs.rank() + 1, lhs.size());
|
||||||
|
|
||||||
|
let mut tmp_in: GLWECiphertextFourier<&mut [u8], FFT64> = GLWECiphertextFourier::<&mut [u8], FFT64> {
|
||||||
|
data: tmp_in_data,
|
||||||
|
basek: lhs.basek(),
|
||||||
|
k: lhs.k(),
|
||||||
|
};
|
||||||
|
|
||||||
|
let (tmp_out_data, scratch2) = scratch1.tmp_vec_znx_dft(module, self.rank() + 1, self.size());
|
||||||
|
|
||||||
|
let mut tmp_out: GLWECiphertextFourier<&mut [u8], FFT64> = GLWECiphertextFourier::<&mut [u8], FFT64> {
|
||||||
|
data: tmp_out_data,
|
||||||
|
basek: self.basek(),
|
||||||
|
k: self.k(),
|
||||||
|
};
|
||||||
|
|
||||||
|
(0..self.rank() + 1).for_each(|col_i| {
|
||||||
|
(0..self.rows()).for_each(|row_j| {
|
||||||
|
lhs.get_row(module, row_j, col_i, &mut tmp_in);
|
||||||
|
tmp_out.external_product(module, &tmp_in, rhs, scratch2);
|
||||||
|
self.set_row(module, row_j, col_i, &tmp_out);
|
||||||
|
});
|
||||||
|
});
|
||||||
|
|
||||||
|
tmp_out.data.zero();
|
||||||
|
|
||||||
|
(self.rows().min(lhs.rows())..self.rows()).for_each(|row_i| {
|
||||||
|
(0..self.rank() + 1).for_each(|col_j| {
|
||||||
|
self.set_row(module, row_i, col_j, &tmp_out);
|
||||||
|
});
|
||||||
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn external_product_inplace<DataRhs>(
|
pub fn external_product_inplace<DataRhs>(
|
||||||
@@ -238,7 +251,32 @@ where
|
|||||||
) where
|
) where
|
||||||
MatZnxDft<DataRhs, FFT64>: MatZnxDftToRef<FFT64>,
|
MatZnxDft<DataRhs, FFT64>: MatZnxDftToRef<FFT64>,
|
||||||
{
|
{
|
||||||
rhs.prod_with_vec_glwe_inplace(module, self, scratch);
|
#[cfg(debug_assertions)]
|
||||||
|
{
|
||||||
|
assert_eq!(
|
||||||
|
self.rank(),
|
||||||
|
rhs.rank(),
|
||||||
|
"ggsw_out rank: {} != ggsw_apply: {}",
|
||||||
|
self.rank(),
|
||||||
|
rhs.rank()
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
let (tmp_data, scratch1) = scratch.tmp_vec_znx_dft(module, self.rank() + 1, self.size());
|
||||||
|
|
||||||
|
let mut tmp: GLWECiphertextFourier<&mut [u8], FFT64> = GLWECiphertextFourier::<&mut [u8], FFT64> {
|
||||||
|
data: tmp_data,
|
||||||
|
basek: self.basek(),
|
||||||
|
k: self.k(),
|
||||||
|
};
|
||||||
|
|
||||||
|
(0..self.rank() + 1).for_each(|col_i| {
|
||||||
|
(0..self.rows()).for_each(|row_j| {
|
||||||
|
self.get_row(module, row_j, col_i, &mut tmp);
|
||||||
|
tmp.external_product_inplace(module, rhs, scratch1);
|
||||||
|
self.set_row(module, row_j, col_i, &tmp);
|
||||||
|
});
|
||||||
|
});
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -270,73 +308,3 @@ where
|
|||||||
module.vmp_prepare_row(self, row_i, col_j, a);
|
module.vmp_prepare_row(self, row_i, col_j, a);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl VecGLWEProductScratchSpace for GGSWCiphertext<Vec<u8>, FFT64> {
|
|
||||||
fn prod_with_glwe_scratch_space(
|
|
||||||
module: &Module<FFT64>,
|
|
||||||
res_size: usize,
|
|
||||||
a_size: usize,
|
|
||||||
rgsw_size: usize,
|
|
||||||
rank_in: usize,
|
|
||||||
rank_out: usize,
|
|
||||||
) -> usize {
|
|
||||||
module.bytes_of_vec_znx_dft(rank_out + 1, rgsw_size)
|
|
||||||
+ ((module.bytes_of_vec_znx_dft(rank_in + 1, a_size)
|
|
||||||
+ module.vmp_apply_tmp_bytes(
|
|
||||||
res_size,
|
|
||||||
a_size,
|
|
||||||
a_size,
|
|
||||||
rank_in + 1,
|
|
||||||
rank_out + 1,
|
|
||||||
rgsw_size,
|
|
||||||
))
|
|
||||||
| module.vec_znx_big_normalize_tmp_bytes())
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl<C> VecGLWEProduct for GGSWCiphertext<C, FFT64>
|
|
||||||
where
|
|
||||||
MatZnxDft<C, FFT64>: MatZnxDftToRef<FFT64> + ZnxInfos,
|
|
||||||
{
|
|
||||||
fn prod_with_glwe<R, A>(
|
|
||||||
&self,
|
|
||||||
module: &Module<FFT64>,
|
|
||||||
res: &mut GLWECiphertext<R>,
|
|
||||||
a: &GLWECiphertext<A>,
|
|
||||||
scratch: &mut Scratch,
|
|
||||||
) where
|
|
||||||
VecZnx<R>: VecZnxToMut,
|
|
||||||
VecZnx<A>: VecZnxToRef,
|
|
||||||
{
|
|
||||||
let log_base2k: usize = self.basek();
|
|
||||||
|
|
||||||
#[cfg(debug_assertions)]
|
|
||||||
{
|
|
||||||
assert_eq!(self.rank(), a.rank());
|
|
||||||
assert_eq!(self.rank(), res.rank());
|
|
||||||
assert_eq!(res.basek(), log_base2k);
|
|
||||||
assert_eq!(a.basek(), log_base2k);
|
|
||||||
assert_eq!(self.n(), module.n());
|
|
||||||
assert_eq!(res.n(), module.n());
|
|
||||||
assert_eq!(a.n(), module.n());
|
|
||||||
}
|
|
||||||
|
|
||||||
let cols: usize = self.rank() + 1;
|
|
||||||
|
|
||||||
let (mut res_dft, scratch1) = scratch.tmp_vec_znx_dft(module, cols, self.size()); // Todo optimise
|
|
||||||
|
|
||||||
{
|
|
||||||
let (mut a_dft, scratch2) = scratch1.tmp_vec_znx_dft(module, cols, a.size());
|
|
||||||
(0..cols).for_each(|col_i| {
|
|
||||||
module.vec_znx_dft(&mut a_dft, col_i, a, col_i);
|
|
||||||
});
|
|
||||||
module.vmp_apply(&mut res_dft, &a_dft, self, scratch2);
|
|
||||||
}
|
|
||||||
|
|
||||||
let res_big: VecZnxBig<&mut [u8], FFT64> = module.vec_znx_idft_consume(res_dft);
|
|
||||||
|
|
||||||
(0..cols).for_each(|i| {
|
|
||||||
module.vec_znx_big_normalize(log_base2k, res, i, &res_big, i, scratch1);
|
|
||||||
});
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|||||||
@@ -1,22 +1,20 @@
|
|||||||
use base2k::{
|
use base2k::{
|
||||||
AddNormal, Backend, FFT64, FillUniform, MatZnxDft, MatZnxDftToRef, Module, ScalarZnxAlloc, ScalarZnxDft, ScalarZnxDftAlloc,
|
AddNormal, Backend, FFT64, FillUniform, MatZnxDft, MatZnxDftOps, MatZnxDftScratch, MatZnxDftToRef, Module, ScalarZnxAlloc,
|
||||||
ScalarZnxDftOps, ScalarZnxDftToRef, Scratch, VecZnx, VecZnxAlloc, VecZnxBig, VecZnxBigAlloc, VecZnxBigOps, VecZnxBigScratch,
|
ScalarZnxDft, ScalarZnxDftAlloc, ScalarZnxDftOps, ScalarZnxDftToRef, Scratch, VecZnx, VecZnxAlloc, VecZnxBig, VecZnxBigAlloc,
|
||||||
VecZnxDft, VecZnxDftAlloc, VecZnxDftOps, VecZnxDftToMut, VecZnxDftToRef, VecZnxOps, VecZnxToMut, VecZnxToRef, ZnxInfos,
|
VecZnxBigOps, VecZnxBigScratch, VecZnxDft, VecZnxDftAlloc, VecZnxDftOps, VecZnxDftToMut, VecZnxDftToRef, VecZnxOps,
|
||||||
ZnxZero,
|
VecZnxToMut, VecZnxToRef, ZnxInfos, ZnxZero,
|
||||||
};
|
};
|
||||||
use sampling::source::Source;
|
use sampling::source::Source;
|
||||||
|
|
||||||
use crate::{
|
use crate::{
|
||||||
SIX_SIGMA,
|
SIX_SIGMA,
|
||||||
elem::Infos,
|
elem::Infos,
|
||||||
gglwe_ciphertext::GGLWECiphertext,
|
|
||||||
ggsw_ciphertext::GGSWCiphertext,
|
ggsw_ciphertext::GGSWCiphertext,
|
||||||
glwe_ciphertext_fourier::GLWECiphertextFourier,
|
glwe_ciphertext_fourier::GLWECiphertextFourier,
|
||||||
glwe_plaintext::GLWEPlaintext,
|
glwe_plaintext::GLWEPlaintext,
|
||||||
keys::{GLWEPublicKey, SecretDistribution, SecretKeyFourier},
|
keys::{GLWEPublicKey, SecretDistribution, SecretKeyFourier},
|
||||||
keyswitch_key::GLWESwitchingKey,
|
keyswitch_key::GLWESwitchingKey,
|
||||||
utils::derive_size,
|
utils::derive_size,
|
||||||
vec_glwe_product::{VecGLWEProduct, VecGLWEProductScratchSpace},
|
|
||||||
};
|
};
|
||||||
|
|
||||||
pub struct GLWECiphertext<C> {
|
pub struct GLWECiphertext<C> {
|
||||||
@@ -115,33 +113,50 @@ impl GLWECiphertext<Vec<u8>> {
|
|||||||
|
|
||||||
pub fn keyswitch_scratch_space(
|
pub fn keyswitch_scratch_space(
|
||||||
module: &Module<FFT64>,
|
module: &Module<FFT64>,
|
||||||
res_size: usize,
|
out_size: usize,
|
||||||
lhs: usize,
|
out_rank: usize,
|
||||||
rhs: usize,
|
in_size: usize,
|
||||||
rank_in: usize,
|
in_rank: usize,
|
||||||
rank_out: usize,
|
ksk_size: usize,
|
||||||
) -> usize {
|
) -> usize {
|
||||||
<GGLWECiphertext<Vec<u8>, FFT64> as VecGLWEProductScratchSpace>::prod_with_glwe_scratch_space(
|
module.bytes_of_vec_znx_dft(out_rank + 1, ksk_size)
|
||||||
module, res_size, lhs, rhs, rank_in, rank_out,
|
+ (module.vec_znx_big_normalize_tmp_bytes()
|
||||||
)
|
| (module.vmp_apply_tmp_bytes(
|
||||||
|
out_size,
|
||||||
|
in_size,
|
||||||
|
in_size,
|
||||||
|
in_rank + 1,
|
||||||
|
out_rank + 1,
|
||||||
|
ksk_size,
|
||||||
|
) + module.bytes_of_vec_znx_dft(in_size, in_size)))
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn keyswitch_inplace_scratch_space(module: &Module<FFT64>, res_size: usize, rhs: usize, rank: usize) -> usize {
|
pub fn keyswitch_inplace_scratch_space(module: &Module<FFT64>, out_size: usize, out_rank: usize, ksk_size: usize) -> usize {
|
||||||
<GGLWECiphertext<Vec<u8>, FFT64> as VecGLWEProductScratchSpace>::prod_with_glwe_inplace_scratch_space(
|
GLWECiphertext::keyswitch_scratch_space(module, out_size, out_rank, out_size, out_rank, ksk_size)
|
||||||
module, res_size, rhs, rank,
|
|
||||||
)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn external_product_scratch_space(module: &Module<FFT64>, res_size: usize, lhs: usize, rhs: usize, rank: usize) -> usize {
|
pub fn external_product_scratch_space(
|
||||||
<GGSWCiphertext<Vec<u8>, FFT64> as VecGLWEProductScratchSpace>::prod_with_glwe_scratch_space(
|
module: &Module<FFT64>,
|
||||||
module, res_size, lhs, rhs, rank, rank,
|
out_size: usize,
|
||||||
)
|
in_size: usize,
|
||||||
|
ggsw_size: usize,
|
||||||
|
rank: usize,
|
||||||
|
) -> usize {
|
||||||
|
module.bytes_of_vec_znx_dft(rank + 1, ggsw_size)
|
||||||
|
+ ((module.bytes_of_vec_znx_dft(rank + 1, in_size)
|
||||||
|
+ module.vmp_apply_tmp_bytes(
|
||||||
|
out_size,
|
||||||
|
in_size,
|
||||||
|
in_size, // rows
|
||||||
|
rank + 1, // cols in
|
||||||
|
rank + 1, // cols out
|
||||||
|
ggsw_size,
|
||||||
|
))
|
||||||
|
| module.vec_znx_big_normalize_tmp_bytes())
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn external_product_inplace_scratch_space(module: &Module<FFT64>, res_size: usize, rhs: usize, rank: usize) -> usize {
|
pub fn external_product_inplace_scratch_space(module: &Module<FFT64>, res_size: usize, rhs: usize, rank: usize) -> usize {
|
||||||
<GGSWCiphertext<Vec<u8>, FFT64> as VecGLWEProductScratchSpace>::prod_with_glwe_inplace_scratch_space(
|
GLWECiphertext::external_product_scratch_space(module, res_size, res_size, rhs, rank)
|
||||||
module, res_size, rhs, rank,
|
|
||||||
)
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -235,7 +250,50 @@ where
|
|||||||
VecZnx<DataLhs>: VecZnxToRef,
|
VecZnx<DataLhs>: VecZnxToRef,
|
||||||
MatZnxDft<DataRhs, FFT64>: MatZnxDftToRef<FFT64>,
|
MatZnxDft<DataRhs, FFT64>: MatZnxDftToRef<FFT64>,
|
||||||
{
|
{
|
||||||
rhs.0.prod_with_glwe(module, self, lhs, scratch);
|
let basek: usize = self.basek();
|
||||||
|
|
||||||
|
#[cfg(debug_assertions)]
|
||||||
|
{
|
||||||
|
assert_eq!(lhs.rank(), rhs.rank_in());
|
||||||
|
assert_eq!(self.rank(), rhs.rank_out());
|
||||||
|
assert_eq!(self.basek(), basek);
|
||||||
|
assert_eq!(lhs.basek(), basek);
|
||||||
|
assert_eq!(rhs.n(), module.n());
|
||||||
|
assert_eq!(self.n(), module.n());
|
||||||
|
assert_eq!(lhs.n(), module.n());
|
||||||
|
assert!(
|
||||||
|
scratch.available()
|
||||||
|
>= GLWECiphertext::keyswitch_scratch_space(
|
||||||
|
module,
|
||||||
|
self.size(),
|
||||||
|
self.rank(),
|
||||||
|
lhs.size(),
|
||||||
|
lhs.rank(),
|
||||||
|
rhs.size(),
|
||||||
|
)
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
let cols_in: usize = rhs.rank_in();
|
||||||
|
let cols_out: usize = rhs.rank_out() + 1;
|
||||||
|
|
||||||
|
let (mut res_dft, scratch1) = scratch.tmp_vec_znx_dft(module, cols_out, rhs.size()); // Todo optimise
|
||||||
|
|
||||||
|
{
|
||||||
|
let (mut ai_dft, scratch2) = scratch1.tmp_vec_znx_dft(module, cols_in, lhs.size());
|
||||||
|
(0..cols_in).for_each(|col_i| {
|
||||||
|
module.vec_znx_dft(&mut ai_dft, col_i, lhs, col_i + 1);
|
||||||
|
});
|
||||||
|
module.vmp_apply(&mut res_dft, &ai_dft, rhs, scratch2);
|
||||||
|
}
|
||||||
|
|
||||||
|
let mut res_big: VecZnxBig<&mut [u8], FFT64> = module.vec_znx_idft_consume(res_dft);
|
||||||
|
|
||||||
|
module.vec_znx_big_add_small_inplace(&mut res_big, 0, lhs, 0);
|
||||||
|
|
||||||
|
(0..cols_out).for_each(|i| {
|
||||||
|
module.vec_znx_big_normalize(basek, self, i, &res_big, i, scratch1);
|
||||||
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn keyswitch_inplace<DataRhs>(
|
pub fn keyswitch_inplace<DataRhs>(
|
||||||
@@ -246,7 +304,10 @@ where
|
|||||||
) where
|
) where
|
||||||
MatZnxDft<DataRhs, FFT64>: MatZnxDftToRef<FFT64>,
|
MatZnxDft<DataRhs, FFT64>: MatZnxDftToRef<FFT64>,
|
||||||
{
|
{
|
||||||
rhs.0.prod_with_glwe_inplace(module, self, scratch);
|
unsafe {
|
||||||
|
let self_ptr: *mut GLWECiphertext<DataSelf> = self as *mut GLWECiphertext<DataSelf>;
|
||||||
|
self.keyswitch(&module, &*self_ptr, rhs, scratch);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn external_product<DataLhs, DataRhs>(
|
pub fn external_product<DataLhs, DataRhs>(
|
||||||
@@ -259,7 +320,36 @@ where
|
|||||||
VecZnx<DataLhs>: VecZnxToRef,
|
VecZnx<DataLhs>: VecZnxToRef,
|
||||||
MatZnxDft<DataRhs, FFT64>: MatZnxDftToRef<FFT64>,
|
MatZnxDft<DataRhs, FFT64>: MatZnxDftToRef<FFT64>,
|
||||||
{
|
{
|
||||||
rhs.prod_with_glwe(module, self, lhs, scratch);
|
let basek: usize = self.basek();
|
||||||
|
|
||||||
|
#[cfg(debug_assertions)]
|
||||||
|
{
|
||||||
|
assert_eq!(rhs.rank(), lhs.rank());
|
||||||
|
assert_eq!(rhs.rank(), self.rank());
|
||||||
|
assert_eq!(self.basek(), basek);
|
||||||
|
assert_eq!(lhs.basek(), basek);
|
||||||
|
assert_eq!(rhs.n(), module.n());
|
||||||
|
assert_eq!(self.n(), module.n());
|
||||||
|
assert_eq!(lhs.n(), module.n());
|
||||||
|
}
|
||||||
|
|
||||||
|
let cols: usize = rhs.rank() + 1;
|
||||||
|
|
||||||
|
let (mut res_dft, scratch1) = scratch.tmp_vec_znx_dft(module, cols, rhs.size()); // Todo optimise
|
||||||
|
|
||||||
|
{
|
||||||
|
let (mut a_dft, scratch2) = scratch1.tmp_vec_znx_dft(module, cols, lhs.size());
|
||||||
|
(0..cols).for_each(|col_i| {
|
||||||
|
module.vec_znx_dft(&mut a_dft, col_i, lhs, col_i);
|
||||||
|
});
|
||||||
|
module.vmp_apply(&mut res_dft, &a_dft, rhs, scratch2);
|
||||||
|
}
|
||||||
|
|
||||||
|
let res_big: VecZnxBig<&mut [u8], FFT64> = module.vec_znx_idft_consume(res_dft);
|
||||||
|
|
||||||
|
(0..cols).for_each(|i| {
|
||||||
|
module.vec_znx_big_normalize(basek, self, i, &res_big, i, scratch1);
|
||||||
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn external_product_inplace<DataRhs>(
|
pub fn external_product_inplace<DataRhs>(
|
||||||
@@ -270,7 +360,10 @@ where
|
|||||||
) where
|
) where
|
||||||
MatZnxDft<DataRhs, FFT64>: MatZnxDftToRef<FFT64>,
|
MatZnxDft<DataRhs, FFT64>: MatZnxDftToRef<FFT64>,
|
||||||
{
|
{
|
||||||
rhs.prod_with_glwe_inplace(module, self, scratch);
|
unsafe {
|
||||||
|
let self_ptr: *mut GLWECiphertext<DataSelf> = self as *mut GLWECiphertext<DataSelf>;
|
||||||
|
self.external_product(&module, &*self_ptr, rhs, scratch);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub(crate) fn encrypt_sk_private<DataPt, DataSk>(
|
pub(crate) fn encrypt_sk_private<DataPt, DataSk>(
|
||||||
|
|||||||
@@ -1,20 +1,13 @@
|
|||||||
use base2k::{
|
use base2k::{
|
||||||
Backend, FFT64, MatZnxDft, MatZnxDftToRef, Module, ScalarZnxDft, ScalarZnxDftOps, ScalarZnxDftToRef, Scratch, VecZnx,
|
Backend, FFT64, MatZnxDft, MatZnxDftOps, MatZnxDftScratch, MatZnxDftToRef, Module, ScalarZnxDft, ScalarZnxDftOps,
|
||||||
VecZnxAlloc, VecZnxBig, VecZnxBigAlloc, VecZnxBigOps, VecZnxBigScratch, VecZnxDft, VecZnxDftAlloc, VecZnxDftOps,
|
ScalarZnxDftToRef, Scratch, VecZnx, VecZnxAlloc, VecZnxBig, VecZnxBigAlloc, VecZnxBigOps, VecZnxBigScratch, VecZnxDft,
|
||||||
VecZnxDftToMut, VecZnxDftToRef, VecZnxToMut, VecZnxToRef, ZnxZero,
|
VecZnxDftAlloc, VecZnxDftOps, VecZnxDftToMut, VecZnxDftToRef, VecZnxToMut, VecZnxToRef, ZnxZero,
|
||||||
};
|
};
|
||||||
use sampling::source::Source;
|
use sampling::source::Source;
|
||||||
|
|
||||||
use crate::{
|
use crate::{
|
||||||
elem::Infos,
|
elem::Infos, ggsw_ciphertext::GGSWCiphertext, glwe_ciphertext::GLWECiphertext, glwe_plaintext::GLWEPlaintext,
|
||||||
gglwe_ciphertext::GGLWECiphertext,
|
keys::SecretKeyFourier, keyswitch_key::GLWESwitchingKey, utils::derive_size,
|
||||||
ggsw_ciphertext::GGSWCiphertext,
|
|
||||||
glwe_ciphertext::GLWECiphertext,
|
|
||||||
glwe_plaintext::GLWEPlaintext,
|
|
||||||
keys::SecretKeyFourier,
|
|
||||||
keyswitch_key::GLWESwitchingKey,
|
|
||||||
utils::derive_size,
|
|
||||||
vec_glwe_product::{VecGLWEProduct, VecGLWEProductScratchSpace},
|
|
||||||
};
|
};
|
||||||
|
|
||||||
pub struct GLWECiphertextFourier<C, B: Backend> {
|
pub struct GLWECiphertextFourier<C, B: Backend> {
|
||||||
@@ -24,11 +17,11 @@ pub struct GLWECiphertextFourier<C, B: Backend> {
|
|||||||
}
|
}
|
||||||
|
|
||||||
impl<B: Backend> GLWECiphertextFourier<Vec<u8>, B> {
|
impl<B: Backend> GLWECiphertextFourier<Vec<u8>, B> {
|
||||||
pub fn new(module: &Module<B>, log_base2k: usize, log_k: usize, rank: usize) -> Self {
|
pub fn new(module: &Module<B>, basek: usize, k: usize, rank: usize) -> Self {
|
||||||
Self {
|
Self {
|
||||||
data: module.new_vec_znx_dft(rank + 1, derive_size(log_base2k, log_k)),
|
data: module.new_vec_znx_dft(rank + 1, derive_size(basek, k)),
|
||||||
basek: log_base2k,
|
basek: basek,
|
||||||
k: log_k,
|
k: k,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -92,33 +85,56 @@ impl GLWECiphertextFourier<Vec<u8>, FFT64> {
|
|||||||
|
|
||||||
pub fn keyswitch_scratch_space(
|
pub fn keyswitch_scratch_space(
|
||||||
module: &Module<FFT64>,
|
module: &Module<FFT64>,
|
||||||
res_size: usize,
|
out_size: usize,
|
||||||
lhs: usize,
|
out_rank: usize,
|
||||||
rhs: usize,
|
in_size: usize,
|
||||||
rank_in: usize,
|
in_rank: usize,
|
||||||
rank_out: usize,
|
ksk_size: usize,
|
||||||
) -> usize {
|
) -> usize {
|
||||||
<GGLWECiphertext<Vec<u8>, FFT64> as VecGLWEProductScratchSpace>::prod_with_glwe_fourier_scratch_space(
|
let res_dft: usize = module.bytes_of_vec_znx_dft(out_rank + 1, out_size);
|
||||||
module, res_size, lhs, rhs, rank_in, rank_out,
|
|
||||||
)
|
let vmp = module.bytes_of_vec_znx_dft(in_rank, in_size)
|
||||||
|
+ module.vmp_apply_tmp_bytes(
|
||||||
|
out_size,
|
||||||
|
in_size,
|
||||||
|
in_size,
|
||||||
|
in_rank + 1,
|
||||||
|
out_rank + 1,
|
||||||
|
ksk_size,
|
||||||
|
);
|
||||||
|
let res_small: usize = module.bytes_of_vec_znx(out_rank + 1, out_size);
|
||||||
|
let add_a0: usize = module.bytes_of_vec_znx_big(1, in_size) + module.vec_znx_idft_tmp_bytes();
|
||||||
|
let normalize: usize = module.vec_znx_big_normalize_tmp_bytes();
|
||||||
|
|
||||||
|
res_dft + (vmp | add_a0 | (res_small + normalize))
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn keyswitch_inplace_scratch_space(module: &Module<FFT64>, res_size: usize, rhs: usize, rank: usize) -> usize {
|
pub fn keyswitch_inplace_scratch_space(module: &Module<FFT64>, out_size: usize, out_rank: usize, ksk_size: usize) -> usize {
|
||||||
<GGLWECiphertext<Vec<u8>, FFT64> as VecGLWEProductScratchSpace>::prod_with_glwe_fourier_inplace_scratch_space(
|
Self::keyswitch_scratch_space(module, out_size, out_rank, out_size, out_rank, ksk_size)
|
||||||
module, res_size, rhs, rank,
|
|
||||||
)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn external_product_scratch_space(module: &Module<FFT64>, res_size: usize, lhs: usize, rhs: usize, rank: usize) -> usize {
|
pub fn external_product_scratch_space(
|
||||||
<GGSWCiphertext<Vec<u8>, FFT64> as VecGLWEProductScratchSpace>::prod_with_glwe_fourier_scratch_space(
|
module: &Module<FFT64>,
|
||||||
module, res_size, lhs, rhs, rank, rank,
|
out_size: usize,
|
||||||
)
|
in_size: usize,
|
||||||
|
ggsw_size: usize,
|
||||||
|
rank: usize,
|
||||||
|
) -> usize {
|
||||||
|
let res_dft: usize = module.bytes_of_vec_znx_dft(rank + 1, out_size);
|
||||||
|
let vmp: usize = module.vmp_apply_tmp_bytes(out_size, in_size, in_size, rank + 1, rank + 1, ggsw_size);
|
||||||
|
let res_small: usize = module.bytes_of_vec_znx(rank + 1, out_size);
|
||||||
|
let normalize: usize = module.vec_znx_big_normalize_tmp_bytes();
|
||||||
|
|
||||||
|
res_dft + (vmp | (res_small + normalize))
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn external_product_inplace_scratch_space(module: &Module<FFT64>, res_size: usize, rhs: usize, rank: usize) -> usize {
|
pub fn external_product_inplace_scratch_space(
|
||||||
<GGSWCiphertext<Vec<u8>, FFT64> as VecGLWEProductScratchSpace>::prod_with_glwe_fourier_inplace_scratch_space(
|
module: &Module<FFT64>,
|
||||||
module, res_size, rhs, rank,
|
out_size: usize,
|
||||||
)
|
ggsw_size: usize,
|
||||||
|
rank: usize,
|
||||||
|
) -> usize {
|
||||||
|
Self::external_product_scratch_space(module, out_size, out_size, ggsw_size, rank)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -158,7 +174,61 @@ where
|
|||||||
VecZnxDft<DataLhs, FFT64>: VecZnxDftToRef<FFT64>,
|
VecZnxDft<DataLhs, FFT64>: VecZnxDftToRef<FFT64>,
|
||||||
MatZnxDft<DataRhs, FFT64>: MatZnxDftToRef<FFT64>,
|
MatZnxDft<DataRhs, FFT64>: MatZnxDftToRef<FFT64>,
|
||||||
{
|
{
|
||||||
rhs.0.prod_with_glwe_fourier(module, self, lhs, scratch);
|
let basek: usize = self.basek();
|
||||||
|
|
||||||
|
#[cfg(debug_assertions)]
|
||||||
|
{
|
||||||
|
assert_eq!(lhs.rank(), rhs.rank_in());
|
||||||
|
assert_eq!(self.rank(), rhs.rank_out());
|
||||||
|
assert_eq!(self.basek(), basek);
|
||||||
|
assert_eq!(lhs.basek(), basek);
|
||||||
|
assert_eq!(rhs.n(), module.n());
|
||||||
|
assert_eq!(self.n(), module.n());
|
||||||
|
assert_eq!(lhs.n(), module.n());
|
||||||
|
assert!(
|
||||||
|
scratch.available()
|
||||||
|
>= GLWECiphertextFourier::keyswitch_scratch_space(
|
||||||
|
module,
|
||||||
|
self.size(),
|
||||||
|
self.rank(),
|
||||||
|
lhs.size(),
|
||||||
|
lhs.rank(),
|
||||||
|
rhs.size(),
|
||||||
|
)
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
let cols_in: usize = rhs.rank_in();
|
||||||
|
let cols_out: usize = rhs.rank_out() + 1;
|
||||||
|
|
||||||
|
// Buffer of the result of VMP in DFT
|
||||||
|
let (mut res_dft, scratch1) = scratch.tmp_vec_znx_dft(module, cols_out, rhs.size()); // Todo optimise
|
||||||
|
|
||||||
|
{
|
||||||
|
// Applies VMP
|
||||||
|
let (mut ai_dft, scratch2) = scratch1.tmp_vec_znx_dft(module, cols_in, lhs.size());
|
||||||
|
(0..cols_in).for_each(|col_i| {
|
||||||
|
module.vec_znx_dft_copy(&mut ai_dft, col_i, lhs, col_i + 1);
|
||||||
|
});
|
||||||
|
module.vmp_apply(&mut res_dft, &ai_dft, rhs, scratch2);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Switches result of VMP outside of DFT
|
||||||
|
let mut res_big: VecZnxBig<&mut [u8], FFT64> = module.vec_znx_idft_consume::<&mut [u8]>(res_dft);
|
||||||
|
|
||||||
|
{
|
||||||
|
// Switches lhs 0-th outside of DFT domain and adds on
|
||||||
|
let (mut a0_big, scratch2) = scratch1.tmp_vec_znx_big(module, 1, lhs.size());
|
||||||
|
module.vec_znx_idft(&mut a0_big, 0, lhs, 0, scratch2);
|
||||||
|
module.vec_znx_big_add_inplace(&mut res_big, 0, &a0_big, 0);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Space fr normalized VMP result outside of DFT domain
|
||||||
|
let (mut res_small, scratch2) = scratch1.tmp_vec_znx(module, cols_out, lhs.size());
|
||||||
|
(0..cols_out).for_each(|i| {
|
||||||
|
module.vec_znx_big_normalize(basek, &mut res_small, i, &res_big, i, scratch2);
|
||||||
|
module.vec_znx_dft(self, i, &res_small, i);
|
||||||
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn keyswitch_inplace<DataRhs>(
|
pub fn keyswitch_inplace<DataRhs>(
|
||||||
@@ -169,7 +239,10 @@ where
|
|||||||
) where
|
) where
|
||||||
MatZnxDft<DataRhs, FFT64>: MatZnxDftToRef<FFT64>,
|
MatZnxDft<DataRhs, FFT64>: MatZnxDftToRef<FFT64>,
|
||||||
{
|
{
|
||||||
rhs.0.prod_with_glwe_fourier_inplace(module, self, scratch);
|
unsafe {
|
||||||
|
let self_ptr: *mut GLWECiphertextFourier<DataSelf, FFT64> = self as *mut GLWECiphertextFourier<DataSelf, FFT64>;
|
||||||
|
self.keyswitch(&module, &*self_ptr, rhs, scratch);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn external_product<DataLhs, DataRhs>(
|
pub fn external_product<DataLhs, DataRhs>(
|
||||||
@@ -182,7 +255,37 @@ where
|
|||||||
VecZnxDft<DataLhs, FFT64>: VecZnxDftToRef<FFT64>,
|
VecZnxDft<DataLhs, FFT64>: VecZnxDftToRef<FFT64>,
|
||||||
MatZnxDft<DataRhs, FFT64>: MatZnxDftToRef<FFT64>,
|
MatZnxDft<DataRhs, FFT64>: MatZnxDftToRef<FFT64>,
|
||||||
{
|
{
|
||||||
rhs.prod_with_glwe_fourier(module, self, lhs, scratch);
|
let basek: usize = self.basek();
|
||||||
|
|
||||||
|
#[cfg(debug_assertions)]
|
||||||
|
{
|
||||||
|
assert_eq!(rhs.rank(), lhs.rank());
|
||||||
|
assert_eq!(rhs.rank(), self.rank());
|
||||||
|
assert_eq!(self.basek(), basek);
|
||||||
|
assert_eq!(lhs.basek(), basek);
|
||||||
|
assert_eq!(rhs.n(), module.n());
|
||||||
|
assert_eq!(self.n(), module.n());
|
||||||
|
assert_eq!(lhs.n(), module.n());
|
||||||
|
}
|
||||||
|
|
||||||
|
let cols: usize = rhs.rank() + 1;
|
||||||
|
|
||||||
|
// Space for VMP result in DFT domain and high precision
|
||||||
|
let (mut res_dft, scratch1) = scratch.tmp_vec_znx_dft(module, cols, rhs.size());
|
||||||
|
|
||||||
|
{
|
||||||
|
module.vmp_apply(&mut res_dft, lhs, rhs, scratch1);
|
||||||
|
}
|
||||||
|
|
||||||
|
// VMP result in high precision
|
||||||
|
let res_big: VecZnxBig<&mut [u8], FFT64> = module.vec_znx_idft_consume::<&mut [u8]>(res_dft);
|
||||||
|
|
||||||
|
// Space for VMP result normalized
|
||||||
|
let (mut res_small, scratch2) = scratch1.tmp_vec_znx(module, cols, rhs.size());
|
||||||
|
(0..cols).for_each(|i| {
|
||||||
|
module.vec_znx_big_normalize(basek, &mut res_small, i, &res_big, i, scratch2);
|
||||||
|
module.vec_znx_dft(self, i, &res_small, i);
|
||||||
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn external_product_inplace<DataRhs>(
|
pub fn external_product_inplace<DataRhs>(
|
||||||
@@ -193,7 +296,10 @@ where
|
|||||||
) where
|
) where
|
||||||
MatZnxDft<DataRhs, FFT64>: MatZnxDftToRef<FFT64>,
|
MatZnxDft<DataRhs, FFT64>: MatZnxDftToRef<FFT64>,
|
||||||
{
|
{
|
||||||
rhs.prod_with_glwe_fourier_inplace(module, self, scratch);
|
unsafe {
|
||||||
|
let self_ptr: *mut GLWECiphertextFourier<DataSelf, FFT64> = self as *mut GLWECiphertextFourier<DataSelf, FFT64>;
|
||||||
|
self.external_product(&module, &*self_ptr, rhs, scratch);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -247,6 +353,7 @@ where
|
|||||||
pt.k = pt.k().min(self.k());
|
pt.k = pt.k().min(self.k());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[allow(dead_code)]
|
||||||
pub(crate) fn idft<DataRes>(&self, module: &Module<FFT64>, res: &mut GLWECiphertext<DataRes>, scratch: &mut Scratch)
|
pub(crate) fn idft<DataRes>(&self, module: &Module<FFT64>, res: &mut GLWECiphertext<DataRes>, scratch: &mut Scratch)
|
||||||
where
|
where
|
||||||
GLWECiphertext<DataRes>: VecZnxToMut,
|
GLWECiphertext<DataRes>: VecZnxToMut,
|
||||||
|
|||||||
@@ -43,10 +43,10 @@ where
|
|||||||
}
|
}
|
||||||
|
|
||||||
impl GLWEPlaintext<Vec<u8>> {
|
impl GLWEPlaintext<Vec<u8>> {
|
||||||
pub fn new<B: Backend>(module: &Module<B>, base2k: usize, k: usize) -> Self {
|
pub fn new<B: Backend>(module: &Module<B>, basek: usize, k: usize) -> Self {
|
||||||
Self {
|
Self {
|
||||||
data: module.new_vec_znx(1, derive_size(base2k, k)),
|
data: module.new_vec_znx(1, derive_size(basek, k)),
|
||||||
basek: base2k,
|
basek: basek,
|
||||||
k,
|
k,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
use base2k::{
|
use base2k::{
|
||||||
Backend, FFT64, MatZnxDft, MatZnxDftOps, MatZnxDftToMut, MatZnxDftToRef, Module, ScalarZnx, ScalarZnxDft, ScalarZnxDftToRef,
|
Backend, FFT64, MatZnxDft, MatZnxDftOps, MatZnxDftToMut, MatZnxDftToRef, Module, ScalarZnx, ScalarZnxDft, ScalarZnxDftToRef,
|
||||||
ScalarZnxToRef, Scratch, VecZnxDft, VecZnxDftToMut, VecZnxDftToRef,
|
ScalarZnxToRef, Scratch, VecZnxDft, VecZnxDftAlloc, VecZnxDftToMut, VecZnxDftToRef, ZnxZero,
|
||||||
};
|
};
|
||||||
use sampling::source::Source;
|
use sampling::source::Source;
|
||||||
|
|
||||||
@@ -10,7 +10,6 @@ use crate::{
|
|||||||
ggsw_ciphertext::GGSWCiphertext,
|
ggsw_ciphertext::GGSWCiphertext,
|
||||||
glwe_ciphertext_fourier::GLWECiphertextFourier,
|
glwe_ciphertext_fourier::GLWECiphertextFourier,
|
||||||
keys::{SecretKey, SecretKeyFourier},
|
keys::{SecretKey, SecretKeyFourier},
|
||||||
vec_glwe_product::{VecGLWEProduct, VecGLWEProductScratchSpace},
|
|
||||||
};
|
};
|
||||||
|
|
||||||
pub struct GLWESwitchingKey<Data, B: Backend>(pub(crate) GGLWECiphertext<Data, B>);
|
pub struct GLWESwitchingKey<Data, B: Backend>(pub(crate) GGLWECiphertext<Data, B>);
|
||||||
@@ -39,6 +38,20 @@ impl<T, B: Backend> Infos for GLWESwitchingKey<T, B> {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
impl<T, B: Backend> GLWESwitchingKey<T, B> {
|
||||||
|
pub fn rank(&self) -> usize {
|
||||||
|
self.0.data.cols_out() - 1
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn rank_in(&self) -> usize {
|
||||||
|
self.0.data.cols_in()
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn rank_out(&self) -> usize {
|
||||||
|
self.0.data.cols_out() - 1
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
impl<DataSelf, B: Backend> MatZnxDftToMut<B> for GLWESwitchingKey<DataSelf, B>
|
impl<DataSelf, B: Backend> MatZnxDftToMut<B> for GLWESwitchingKey<DataSelf, B>
|
||||||
where
|
where
|
||||||
MatZnxDft<DataSelf, B>: MatZnxDftToMut<B>,
|
MatZnxDft<DataSelf, B>: MatZnxDftToMut<B>,
|
||||||
@@ -131,33 +144,46 @@ where
|
|||||||
impl GLWESwitchingKey<Vec<u8>, FFT64> {
|
impl GLWESwitchingKey<Vec<u8>, FFT64> {
|
||||||
pub fn keyswitch_scratch_space(
|
pub fn keyswitch_scratch_space(
|
||||||
module: &Module<FFT64>,
|
module: &Module<FFT64>,
|
||||||
res_size: usize,
|
out_size: usize,
|
||||||
lhs: usize,
|
out_rank: usize,
|
||||||
rhs: usize,
|
in_size: usize,
|
||||||
rank_in: usize,
|
in_rank: usize,
|
||||||
rank_out: usize,
|
ksk_size: usize,
|
||||||
) -> usize {
|
) -> usize {
|
||||||
<GGLWECiphertext<Vec<u8>, FFT64> as VecGLWEProductScratchSpace>::prod_with_vec_glwe_scratch_space(
|
let tmp_in: usize = module.bytes_of_vec_znx_dft(in_rank + 1, in_size);
|
||||||
module, res_size, lhs, rhs, rank_in, rank_out,
|
let tmp_out: usize = module.bytes_of_vec_znx_dft(out_rank + 1, out_size);
|
||||||
)
|
let ksk: usize = GLWECiphertextFourier::keyswitch_scratch_space(module, out_size, out_rank, in_size, in_rank, ksk_size);
|
||||||
|
tmp_in + tmp_out + ksk
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn keyswitch_inplace_scratch_space(module: &Module<FFT64>, res_size: usize, rhs: usize, rank: usize) -> usize {
|
pub fn keyswitch_inplace_scratch_space(module: &Module<FFT64>, out_size: usize, out_rank: usize, ksk_size: usize) -> usize {
|
||||||
<GGLWECiphertext<Vec<u8>, FFT64> as VecGLWEProductScratchSpace>::prod_with_vec_glwe_inplace_scratch_space(
|
let tmp: usize = module.bytes_of_vec_znx_dft(out_rank + 1, out_size);
|
||||||
module, res_size, rhs, rank,
|
let ksk: usize = GLWECiphertextFourier::keyswitch_inplace_scratch_space(module, out_size, out_rank, ksk_size);
|
||||||
)
|
tmp + ksk
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn external_product_scratch_space(module: &Module<FFT64>, res_size: usize, lhs: usize, rhs: usize, rank: usize) -> usize {
|
pub fn external_product_scratch_space(
|
||||||
<GGSWCiphertext<Vec<u8>, FFT64> as VecGLWEProductScratchSpace>::prod_with_vec_glwe_scratch_space(
|
module: &Module<FFT64>,
|
||||||
module, res_size, lhs, rhs, rank, rank,
|
out_size: usize,
|
||||||
)
|
in_size: usize,
|
||||||
|
ggsw_size: usize,
|
||||||
|
rank: usize,
|
||||||
|
) -> usize {
|
||||||
|
let tmp_in: usize = module.bytes_of_vec_znx_dft(rank + 1, in_size);
|
||||||
|
let tmp_out: usize = module.bytes_of_vec_znx_dft(rank + 1, out_size);
|
||||||
|
let ggsw: usize = GLWECiphertextFourier::external_product_scratch_space(module, out_size, in_size, ggsw_size, rank);
|
||||||
|
tmp_in + tmp_out + ggsw
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn external_product_inplace_scratch_space(module: &Module<FFT64>, res_size: usize, rhs: usize, rank: usize) -> usize {
|
pub fn external_product_inplace_scratch_space(
|
||||||
<GGSWCiphertext<Vec<u8>, FFT64> as VecGLWEProductScratchSpace>::prod_with_glwe_inplace_scratch_space(
|
module: &Module<FFT64>,
|
||||||
module, res_size, rhs, rank,
|
out_size: usize,
|
||||||
)
|
ggsw_size: usize,
|
||||||
|
rank: usize,
|
||||||
|
) -> usize {
|
||||||
|
let tmp: usize = module.bytes_of_vec_znx_dft(rank + 1, out_size);
|
||||||
|
let ggsw: usize = GLWECiphertextFourier::external_product_inplace_scratch_space(module, out_size, ggsw_size, rank);
|
||||||
|
tmp + ggsw
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -175,8 +201,62 @@ where
|
|||||||
MatZnxDft<DataLhs, FFT64>: MatZnxDftToRef<FFT64>,
|
MatZnxDft<DataLhs, FFT64>: MatZnxDftToRef<FFT64>,
|
||||||
MatZnxDft<DataRhs, FFT64>: MatZnxDftToRef<FFT64>,
|
MatZnxDft<DataRhs, FFT64>: MatZnxDftToRef<FFT64>,
|
||||||
{
|
{
|
||||||
rhs.0
|
#[cfg(debug_assertions)]
|
||||||
.prod_with_vec_glwe(module, &mut self.0, &lhs.0, scratch);
|
{
|
||||||
|
assert_eq!(
|
||||||
|
self.rank_in(),
|
||||||
|
lhs.rank_in(),
|
||||||
|
"ksk_out input rank: {} != ksk_in input rank: {}",
|
||||||
|
self.rank_in(),
|
||||||
|
lhs.rank_in()
|
||||||
|
);
|
||||||
|
assert_eq!(
|
||||||
|
lhs.rank_out(),
|
||||||
|
rhs.rank_in(),
|
||||||
|
"ksk_in output rank: {} != ksk_apply input rank: {}",
|
||||||
|
self.rank_out(),
|
||||||
|
rhs.rank_in()
|
||||||
|
);
|
||||||
|
assert_eq!(
|
||||||
|
self.rank_out(),
|
||||||
|
rhs.rank_out(),
|
||||||
|
"ksk_out output rank: {} != ksk_apply output rank: {}",
|
||||||
|
self.rank_out(),
|
||||||
|
rhs.rank_out()
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
let (tmp_in_data, scratch1) = scratch.tmp_vec_znx_dft(module, lhs.rank_out() + 1, lhs.size());
|
||||||
|
|
||||||
|
let mut tmp_in: GLWECiphertextFourier<&mut [u8], FFT64> = GLWECiphertextFourier::<&mut [u8], FFT64> {
|
||||||
|
data: tmp_in_data,
|
||||||
|
basek: lhs.basek(),
|
||||||
|
k: lhs.k(),
|
||||||
|
};
|
||||||
|
|
||||||
|
let (tmp_out_data, scratch2) = scratch1.tmp_vec_znx_dft(module, self.rank_out() + 1, self.size());
|
||||||
|
|
||||||
|
let mut tmp_out: GLWECiphertextFourier<&mut [u8], FFT64> = GLWECiphertextFourier::<&mut [u8], FFT64> {
|
||||||
|
data: tmp_out_data,
|
||||||
|
basek: self.basek(),
|
||||||
|
k: self.k(),
|
||||||
|
};
|
||||||
|
|
||||||
|
(0..self.rank_in()).for_each(|col_i| {
|
||||||
|
(0..self.rows()).for_each(|row_j| {
|
||||||
|
lhs.get_row(module, row_j, col_i, &mut tmp_in);
|
||||||
|
tmp_out.keyswitch(module, &tmp_in, rhs, scratch2);
|
||||||
|
self.set_row(module, row_j, col_i, &tmp_out);
|
||||||
|
});
|
||||||
|
});
|
||||||
|
|
||||||
|
tmp_out.data.zero();
|
||||||
|
|
||||||
|
(self.rows().min(lhs.rows())..self.rows()).for_each(|row_i| {
|
||||||
|
(0..self.rank_in()).for_each(|col_j| {
|
||||||
|
self.set_row(module, row_i, col_j, &tmp_out);
|
||||||
|
});
|
||||||
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn keyswitch_inplace<DataRhs>(
|
pub fn keyswitch_inplace<DataRhs>(
|
||||||
@@ -187,8 +267,32 @@ where
|
|||||||
) where
|
) where
|
||||||
MatZnxDft<DataRhs, FFT64>: MatZnxDftToRef<FFT64>,
|
MatZnxDft<DataRhs, FFT64>: MatZnxDftToRef<FFT64>,
|
||||||
{
|
{
|
||||||
rhs.0
|
#[cfg(debug_assertions)]
|
||||||
.prod_with_vec_glwe_inplace(module, &mut self.0, scratch);
|
{
|
||||||
|
assert_eq!(
|
||||||
|
self.rank_out(),
|
||||||
|
rhs.rank_out(),
|
||||||
|
"ksk_out output rank: {} != ksk_apply output rank: {}",
|
||||||
|
self.rank_out(),
|
||||||
|
rhs.rank_out()
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
let (tmp_data, scratch1) = scratch.tmp_vec_znx_dft(module, self.rank_out() + 1, self.size());
|
||||||
|
|
||||||
|
let mut tmp: GLWECiphertextFourier<&mut [u8], FFT64> = GLWECiphertextFourier::<&mut [u8], FFT64> {
|
||||||
|
data: tmp_data,
|
||||||
|
basek: self.basek(),
|
||||||
|
k: self.k(),
|
||||||
|
};
|
||||||
|
|
||||||
|
(0..self.rank_in()).for_each(|col_i| {
|
||||||
|
(0..self.rows()).for_each(|row_j| {
|
||||||
|
self.get_row(module, row_j, col_i, &mut tmp);
|
||||||
|
tmp.keyswitch_inplace(module, rhs, scratch1);
|
||||||
|
self.set_row(module, row_j, col_i, &tmp);
|
||||||
|
});
|
||||||
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn external_product<DataLhs, DataRhs>(
|
pub fn external_product<DataLhs, DataRhs>(
|
||||||
@@ -201,7 +305,62 @@ where
|
|||||||
MatZnxDft<DataLhs, FFT64>: MatZnxDftToRef<FFT64>,
|
MatZnxDft<DataLhs, FFT64>: MatZnxDftToRef<FFT64>,
|
||||||
MatZnxDft<DataRhs, FFT64>: MatZnxDftToRef<FFT64>,
|
MatZnxDft<DataRhs, FFT64>: MatZnxDftToRef<FFT64>,
|
||||||
{
|
{
|
||||||
rhs.prod_with_vec_glwe(module, &mut self.0, &lhs.0, scratch);
|
#[cfg(debug_assertions)]
|
||||||
|
{
|
||||||
|
assert_eq!(
|
||||||
|
self.rank_in(),
|
||||||
|
lhs.rank_in(),
|
||||||
|
"ksk_out input rank: {} != ksk_in input rank: {}",
|
||||||
|
self.rank_in(),
|
||||||
|
lhs.rank_in()
|
||||||
|
);
|
||||||
|
assert_eq!(
|
||||||
|
lhs.rank_out(),
|
||||||
|
rhs.rank(),
|
||||||
|
"ksk_in output rank: {} != ggsw rank: {}",
|
||||||
|
self.rank_out(),
|
||||||
|
rhs.rank()
|
||||||
|
);
|
||||||
|
assert_eq!(
|
||||||
|
self.rank_out(),
|
||||||
|
rhs.rank(),
|
||||||
|
"ksk_out output rank: {} != ggsw rank: {}",
|
||||||
|
self.rank_out(),
|
||||||
|
rhs.rank()
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
let (tmp_in_data, scratch1) = scratch.tmp_vec_znx_dft(module, lhs.rank_out() + 1, lhs.size());
|
||||||
|
|
||||||
|
let mut tmp_in: GLWECiphertextFourier<&mut [u8], FFT64> = GLWECiphertextFourier::<&mut [u8], FFT64> {
|
||||||
|
data: tmp_in_data,
|
||||||
|
basek: lhs.basek(),
|
||||||
|
k: lhs.k(),
|
||||||
|
};
|
||||||
|
|
||||||
|
let (tmp_out_data, scratch2) = scratch1.tmp_vec_znx_dft(module, self.rank_out() + 1, self.size());
|
||||||
|
|
||||||
|
let mut tmp_out: GLWECiphertextFourier<&mut [u8], FFT64> = GLWECiphertextFourier::<&mut [u8], FFT64> {
|
||||||
|
data: tmp_out_data,
|
||||||
|
basek: self.basek(),
|
||||||
|
k: self.k(),
|
||||||
|
};
|
||||||
|
|
||||||
|
(0..self.rank_in()).for_each(|col_i| {
|
||||||
|
(0..self.rows()).for_each(|row_j| {
|
||||||
|
lhs.get_row(module, row_j, col_i, &mut tmp_in);
|
||||||
|
tmp_out.external_product(module, &tmp_in, rhs, scratch2);
|
||||||
|
self.set_row(module, row_j, col_i, &tmp_out);
|
||||||
|
});
|
||||||
|
});
|
||||||
|
|
||||||
|
tmp_out.data.zero();
|
||||||
|
|
||||||
|
(self.rows().min(lhs.rows())..self.rows()).for_each(|row_i| {
|
||||||
|
(0..self.rank_in()).for_each(|col_j| {
|
||||||
|
self.set_row(module, row_i, col_j, &tmp_out);
|
||||||
|
});
|
||||||
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn external_product_inplace<DataRhs>(
|
pub fn external_product_inplace<DataRhs>(
|
||||||
@@ -212,6 +371,31 @@ where
|
|||||||
) where
|
) where
|
||||||
MatZnxDft<DataRhs, FFT64>: MatZnxDftToRef<FFT64>,
|
MatZnxDft<DataRhs, FFT64>: MatZnxDftToRef<FFT64>,
|
||||||
{
|
{
|
||||||
rhs.prod_with_vec_glwe_inplace(module, &mut self.0, scratch);
|
#[cfg(debug_assertions)]
|
||||||
|
{
|
||||||
|
assert_eq!(
|
||||||
|
self.rank_out(),
|
||||||
|
rhs.rank(),
|
||||||
|
"ksk_out output rank: {} != ggsw rank: {}",
|
||||||
|
self.rank_out(),
|
||||||
|
rhs.rank()
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
let (tmp_data, scratch1) = scratch.tmp_vec_znx_dft(module, self.rank_out() + 1, self.size());
|
||||||
|
|
||||||
|
let mut tmp: GLWECiphertextFourier<&mut [u8], FFT64> = GLWECiphertextFourier::<&mut [u8], FFT64> {
|
||||||
|
data: tmp_data,
|
||||||
|
basek: self.basek(),
|
||||||
|
k: self.k(),
|
||||||
|
};
|
||||||
|
|
||||||
|
(0..self.rank_in()).for_each(|col_i| {
|
||||||
|
(0..self.rows()).for_each(|row_j| {
|
||||||
|
self.get_row(module, row_j, col_i, &mut tmp);
|
||||||
|
tmp.external_product_inplace(module, rhs, scratch1);
|
||||||
|
self.set_row(module, row_j, col_i, &tmp);
|
||||||
|
});
|
||||||
|
});
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -9,6 +9,5 @@ pub mod keyswitch_key;
|
|||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
mod test_fft64;
|
mod test_fft64;
|
||||||
mod utils;
|
mod utils;
|
||||||
pub mod vec_glwe_product;
|
|
||||||
|
|
||||||
pub(crate) const SIX_SIGMA: f64 = 6.0;
|
pub(crate) const SIX_SIGMA: f64 = 6.0;
|
||||||
|
|||||||
File diff suppressed because it is too large
Load Diff
@@ -1,575 +1,345 @@
|
|||||||
// use base2k::{
|
use base2k::{
|
||||||
// FFT64, Module, ScalarZnx, ScalarZnxAlloc, ScalarZnxDftOps, ScratchOwned, Stats, VecZnxBig, VecZnxBigAlloc, VecZnxBigOps,
|
FFT64, Module, ScalarZnx, ScalarZnxAlloc, ScalarZnxDftOps, ScratchOwned, Stats, VecZnxBig, VecZnxBigAlloc, VecZnxBigOps,
|
||||||
// VecZnxDft, VecZnxDftAlloc, VecZnxDftOps, VecZnxOps, VecZnxToMut, ZnxViewMut, ZnxZero,
|
VecZnxDft, VecZnxDftAlloc, VecZnxDftOps, VecZnxOps, VecZnxToMut, ZnxViewMut, ZnxZero,
|
||||||
// };
|
};
|
||||||
// use sampling::source::Source;
|
use sampling::source::Source;
|
||||||
//
|
|
||||||
// use crate::{
|
use crate::{
|
||||||
// elem::{GetRow, Infos},
|
elem::{GetRow, Infos},
|
||||||
// ggsw_ciphertext::GGSWCiphertext,
|
ggsw_ciphertext::GGSWCiphertext,
|
||||||
// glwe_ciphertext_fourier::GLWECiphertextFourier,
|
glwe_ciphertext_fourier::GLWECiphertextFourier,
|
||||||
// glwe_plaintext::GLWEPlaintext,
|
glwe_plaintext::GLWEPlaintext,
|
||||||
// keys::{SecretKey, SecretKeyFourier},
|
keys::{SecretKey, SecretKeyFourier},
|
||||||
// keyswitch_key::GLWESwitchingKey,
|
keyswitch_key::GLWESwitchingKey,
|
||||||
// test_fft64::gglwe::noise_grlwe_rlwe_product,
|
};
|
||||||
// };
|
|
||||||
//
|
#[test]
|
||||||
// #[test]
|
fn encrypt_sk() {
|
||||||
// fn encrypt_sk() {
|
(1..4).for_each(|rank| {
|
||||||
// let module: Module<FFT64> = Module::<FFT64>::new(2048);
|
println!("test encrypt_sk rank: {}", rank);
|
||||||
// let log_base2k: usize = 8;
|
test_encrypt_sk(11, 8, 54, 3.2, rank);
|
||||||
// let log_k_ct: usize = 54;
|
});
|
||||||
// let rows: usize = 4;
|
}
|
||||||
// let rank: usize = 1;
|
|
||||||
//
|
#[test]
|
||||||
// let sigma: f64 = 3.2;
|
fn external_product() {
|
||||||
//
|
(1..4).for_each(|rank| {
|
||||||
// let mut ct: GGSWCiphertext<Vec<u8>, FFT64> = GGSWCiphertext::new(&module, log_base2k, log_k_ct, rows, rank);
|
println!("test external_product rank: {}", rank);
|
||||||
// let mut pt_have: GLWEPlaintext<Vec<u8>> = GLWEPlaintext::new(&module, log_base2k, log_k_ct);
|
test_external_product(12, 12, 60, rank, 3.2);
|
||||||
// let mut pt_want: GLWEPlaintext<Vec<u8>> = GLWEPlaintext::new(&module, log_base2k, log_k_ct);
|
});
|
||||||
// let mut pt_scalar: ScalarZnx<Vec<u8>> = module.new_scalar_znx(1);
|
}
|
||||||
//
|
|
||||||
// let mut source_xs: Source = Source::new([0u8; 32]);
|
#[test]
|
||||||
// let mut source_xe: Source = Source::new([0u8; 32]);
|
fn external_product_inplace() {
|
||||||
// let mut source_xa: Source = Source::new([0u8; 32]);
|
(1..4).for_each(|rank| {
|
||||||
//
|
println!("test external_product rank: {}", rank);
|
||||||
// pt_scalar.fill_ternary_hw(0, module.n(), &mut source_xs);
|
test_external_product_inplace(12, 15, 60, rank, 3.2);
|
||||||
//
|
});
|
||||||
// let mut scratch: ScratchOwned = ScratchOwned::new(
|
}
|
||||||
// GGSWCiphertext::encrypt_sk_scratch_space(&module, rank, ct.size())
|
|
||||||
// | GLWECiphertextFourier::decrypt_scratch_space(&module, ct.size()),
|
fn test_encrypt_sk(log_n: usize, basek: usize, k_ggsw: usize, sigma: f64, rank: usize) {
|
||||||
// );
|
let module: Module<FFT64> = Module::<FFT64>::new(1 << log_n);
|
||||||
//
|
|
||||||
// let mut sk: SecretKey<Vec<u8>> = SecretKey::new(&module, rank);
|
let rows: usize = (k_ggsw + basek - 1) / basek;
|
||||||
// sk.fill_ternary_prob(0.5, &mut source_xs);
|
|
||||||
//
|
let mut ct: GGSWCiphertext<Vec<u8>, FFT64> = GGSWCiphertext::new(&module, basek, k_ggsw, rows, rank);
|
||||||
// let mut sk_dft: SecretKeyFourier<Vec<u8>, FFT64> = SecretKeyFourier::new(&module, rank);
|
let mut pt_have: GLWEPlaintext<Vec<u8>> = GLWEPlaintext::new(&module, basek, k_ggsw);
|
||||||
// sk_dft.dft(&module, &sk);
|
let mut pt_want: GLWEPlaintext<Vec<u8>> = GLWEPlaintext::new(&module, basek, k_ggsw);
|
||||||
//
|
let mut pt_scalar: ScalarZnx<Vec<u8>> = module.new_scalar_znx(1);
|
||||||
// ct.encrypt_sk(
|
|
||||||
// &module,
|
let mut source_xs: Source = Source::new([0u8; 32]);
|
||||||
// &pt_scalar,
|
let mut source_xe: Source = Source::new([0u8; 32]);
|
||||||
// &sk_dft,
|
let mut source_xa: Source = Source::new([0u8; 32]);
|
||||||
// &mut source_xa,
|
|
||||||
// &mut source_xe,
|
pt_scalar.fill_ternary_hw(0, module.n(), &mut source_xs);
|
||||||
// sigma,
|
|
||||||
// scratch.borrow(),
|
let mut scratch: ScratchOwned = ScratchOwned::new(
|
||||||
// );
|
GGSWCiphertext::encrypt_sk_scratch_space(&module, rank, ct.size())
|
||||||
//
|
| GLWECiphertextFourier::decrypt_scratch_space(&module, ct.size()),
|
||||||
// let mut ct_rlwe_dft: GLWECiphertextFourier<Vec<u8>, FFT64> = GLWECiphertextFourier::new(&module, log_base2k, log_k_ct, rank);
|
);
|
||||||
// let mut pt_dft: VecZnxDft<Vec<u8>, FFT64> = module.new_vec_znx_dft(1, ct.size());
|
|
||||||
// let mut pt_big: VecZnxBig<Vec<u8>, FFT64> = module.new_vec_znx_big(1, ct.size());
|
let mut sk: SecretKey<Vec<u8>> = SecretKey::new(&module, rank);
|
||||||
//
|
sk.fill_ternary_prob(0.5, &mut source_xs);
|
||||||
// (0..ct.rank()).for_each(|col_j| {
|
|
||||||
// (0..ct.rows()).for_each(|row_i| {
|
let mut sk_dft: SecretKeyFourier<Vec<u8>, FFT64> = SecretKeyFourier::new(&module, rank);
|
||||||
// module.vec_znx_add_scalar_inplace(&mut pt_want, 0, row_i, &pt_scalar, 0);
|
sk_dft.dft(&module, &sk);
|
||||||
//
|
|
||||||
// if col_j == 1 {
|
ct.encrypt_sk(
|
||||||
// module.vec_znx_dft(&mut pt_dft, 0, &pt_want, 0);
|
&module,
|
||||||
// module.svp_apply_inplace(&mut pt_dft, 0, &sk_dft, 0);
|
&pt_scalar,
|
||||||
// module.vec_znx_idft_tmp_a(&mut pt_big, 0, &mut pt_dft, 0);
|
&sk_dft,
|
||||||
// module.vec_znx_big_normalize(log_base2k, &mut pt_want, 0, &pt_big, 0, scratch.borrow());
|
&mut source_xa,
|
||||||
// }
|
&mut source_xe,
|
||||||
//
|
sigma,
|
||||||
// ct.get_row(&module, row_i, col_j, &mut ct_rlwe_dft);
|
scratch.borrow(),
|
||||||
//
|
);
|
||||||
// ct_rlwe_dft.decrypt(&module, &mut pt_have, &sk_dft, scratch.borrow());
|
|
||||||
//
|
let mut ct_glwe_fourier: GLWECiphertextFourier<Vec<u8>, FFT64> = GLWECiphertextFourier::new(&module, basek, k_ggsw, rank);
|
||||||
// module.vec_znx_sub_ab_inplace(&mut pt_have, 0, &pt_want, 0);
|
let mut pt_dft: VecZnxDft<Vec<u8>, FFT64> = module.new_vec_znx_dft(1, ct.size());
|
||||||
//
|
let mut pt_big: VecZnxBig<Vec<u8>, FFT64> = module.new_vec_znx_big(1, ct.size());
|
||||||
// let std_pt: f64 = pt_have.data.std(0, log_base2k) * (log_k_ct as f64).exp2();
|
|
||||||
// assert!((sigma - std_pt).abs() <= 0.2, "{} {}", sigma, std_pt);
|
(0..ct.rank() + 1).for_each(|col_j| {
|
||||||
//
|
(0..ct.rows()).for_each(|row_i| {
|
||||||
// pt_want.data.zero();
|
module.vec_znx_add_scalar_inplace(&mut pt_want, 0, row_i, &pt_scalar, 0);
|
||||||
// });
|
|
||||||
// });
|
// mul with sk[col_j-1]
|
||||||
// }
|
if col_j > 0 {
|
||||||
//
|
module.vec_znx_dft(&mut pt_dft, 0, &pt_want, 0);
|
||||||
// #[test]
|
module.svp_apply_inplace(&mut pt_dft, 0, &sk_dft, col_j - 1);
|
||||||
// fn keyswitch() {
|
module.vec_znx_idft_tmp_a(&mut pt_big, 0, &mut pt_dft, 0);
|
||||||
// let module: Module<FFT64> = Module::<FFT64>::new(2048);
|
module.vec_znx_big_normalize(basek, &mut pt_want, 0, &pt_big, 0, scratch.borrow());
|
||||||
// let log_base2k: usize = 12;
|
}
|
||||||
// let log_k_grlwe: usize = 60;
|
|
||||||
// let log_k_rgsw_in: usize = 45;
|
ct.get_row(&module, row_i, col_j, &mut ct_glwe_fourier);
|
||||||
// let log_k_rgsw_out: usize = 45;
|
|
||||||
// let rows: usize = (log_k_rgsw_in + log_base2k - 1) / log_base2k;
|
ct_glwe_fourier.decrypt(&module, &mut pt_have, &sk_dft, scratch.borrow());
|
||||||
//
|
|
||||||
// let rank: usize = 1;
|
module.vec_znx_sub_ab_inplace(&mut pt_have, 0, &pt_want, 0);
|
||||||
//
|
|
||||||
// let sigma: f64 = 3.2;
|
let std_pt: f64 = pt_have.data.std(0, basek) * (k_ggsw as f64).exp2();
|
||||||
//
|
assert!((sigma - std_pt).abs() <= 0.2, "{} {}", sigma, std_pt);
|
||||||
// let mut ct_grlwe: GLWESwitchingKey<Vec<u8>, FFT64> =
|
|
||||||
// GLWESwitchingKey::new(&module, log_base2k, log_k_grlwe, rows, rank, rank);
|
pt_want.data.zero();
|
||||||
// let mut ct_rgsw_in: GGSWCiphertext<Vec<u8>, FFT64> = GGSWCiphertext::new(&module, log_base2k, log_k_rgsw_in, rows, rank);
|
});
|
||||||
// let mut ct_rgsw_out: GGSWCiphertext<Vec<u8>, FFT64> = GGSWCiphertext::new(&module, log_base2k, log_k_rgsw_out, rows, rank);
|
});
|
||||||
// let mut pt_rgsw: ScalarZnx<Vec<u8>> = module.new_scalar_znx(1);
|
}
|
||||||
//
|
|
||||||
// let mut source_xs: Source = Source::new([0u8; 32]);
|
fn test_external_product(log_n: usize, basek: usize, k_ggsw: usize, rank: usize, sigma: f64) {
|
||||||
// let mut source_xe: Source = Source::new([0u8; 32]);
|
let module: Module<FFT64> = Module::<FFT64>::new(1 << log_n);
|
||||||
// let mut source_xa: Source = Source::new([0u8; 32]);
|
|
||||||
//
|
let rows: usize = (k_ggsw + basek - 1) / basek;
|
||||||
// Random input plaintext
|
|
||||||
// pt_rgsw.fill_ternary_prob(0, 0.5, &mut source_xs);
|
let mut ct_ggsw_rhs: GGSWCiphertext<Vec<u8>, FFT64> = GGSWCiphertext::new(&module, basek, k_ggsw, rows, rank);
|
||||||
//
|
let mut ct_ggsw_lhs_in: GGSWCiphertext<Vec<u8>, FFT64> = GGSWCiphertext::new(&module, basek, k_ggsw, rows, rank);
|
||||||
// let mut scratch: ScratchOwned = ScratchOwned::new(
|
let mut ct_ggsw_lhs_out: GGSWCiphertext<Vec<u8>, FFT64> = GGSWCiphertext::new(&module, basek, k_ggsw, rows, rank);
|
||||||
// GLWESwitchingKey::encrypt_sk_scratch_space(&module, rank, ct_grlwe.size())
|
let mut pt_ggsw_lhs: ScalarZnx<Vec<u8>> = module.new_scalar_znx(1);
|
||||||
// | GLWECiphertextFourier::decrypt_scratch_space(&module, ct_rgsw_out.size())
|
let mut pt_ggsw_rhs: ScalarZnx<Vec<u8>> = module.new_scalar_znx(1);
|
||||||
// | GGSWCiphertext::encrypt_sk_scratch_space(&module, rank, ct_rgsw_in.size())
|
|
||||||
// | GGSWCiphertext::keyswitch_scratch_space(
|
let mut source_xs: Source = Source::new([0u8; 32]);
|
||||||
// &module,
|
let mut source_xe: Source = Source::new([0u8; 32]);
|
||||||
// ct_rgsw_out.size(),
|
let mut source_xa: Source = Source::new([0u8; 32]);
|
||||||
// ct_rgsw_in.size(),
|
|
||||||
// ct_grlwe.size(),
|
pt_ggsw_lhs.fill_ternary_prob(0, 0.5, &mut source_xs);
|
||||||
// ),
|
|
||||||
// );
|
let k: usize = 1;
|
||||||
//
|
|
||||||
// let mut sk0: SecretKey<Vec<u8>> = SecretKey::new(&module, rank);
|
pt_ggsw_rhs.to_mut().raw_mut()[k] = 1; //X^{k}
|
||||||
// sk0.fill_ternary_prob(0.5, &mut source_xs);
|
|
||||||
//
|
let mut scratch: ScratchOwned = ScratchOwned::new(
|
||||||
// let mut sk0_dft: SecretKeyFourier<Vec<u8>, FFT64> = SecretKeyFourier::new(&module, rank);
|
GLWESwitchingKey::encrypt_sk_scratch_space(&module, rank, ct_ggsw_rhs.size())
|
||||||
// sk0_dft.dft(&module, &sk0);
|
| GLWECiphertextFourier::decrypt_scratch_space(&module, ct_ggsw_lhs_out.size())
|
||||||
//
|
| GGSWCiphertext::encrypt_sk_scratch_space(&module, rank, ct_ggsw_lhs_in.size())
|
||||||
// let mut sk1: SecretKey<Vec<u8>> = SecretKey::new(&module, rank);
|
| GGSWCiphertext::external_product_scratch_space(
|
||||||
// sk1.fill_ternary_prob(0.5, &mut source_xs);
|
&module,
|
||||||
//
|
ct_ggsw_lhs_out.size(),
|
||||||
// let mut sk1_dft: SecretKeyFourier<Vec<u8>, FFT64> = SecretKeyFourier::new(&module, rank);
|
ct_ggsw_lhs_in.size(),
|
||||||
// sk1_dft.dft(&module, &sk1);
|
ct_ggsw_rhs.size(),
|
||||||
//
|
rank,
|
||||||
// ct_grlwe.encrypt_sk(
|
),
|
||||||
// &module,
|
);
|
||||||
// &sk0.data,
|
|
||||||
// &sk1_dft,
|
let mut sk: SecretKey<Vec<u8>> = SecretKey::new(&module, rank);
|
||||||
// &mut source_xa,
|
sk.fill_ternary_prob(0.5, &mut source_xs);
|
||||||
// &mut source_xe,
|
|
||||||
// sigma,
|
let mut sk_dft: SecretKeyFourier<Vec<u8>, FFT64> = SecretKeyFourier::new(&module, rank);
|
||||||
// scratch.borrow(),
|
sk_dft.dft(&module, &sk);
|
||||||
// );
|
|
||||||
//
|
ct_ggsw_rhs.encrypt_sk(
|
||||||
// ct_rgsw_in.encrypt_sk(
|
&module,
|
||||||
// &module,
|
&pt_ggsw_rhs,
|
||||||
// &pt_rgsw,
|
&sk_dft,
|
||||||
// &sk0_dft,
|
&mut source_xa,
|
||||||
// &mut source_xa,
|
&mut source_xe,
|
||||||
// &mut source_xe,
|
sigma,
|
||||||
// sigma,
|
scratch.borrow(),
|
||||||
// scratch.borrow(),
|
);
|
||||||
// );
|
|
||||||
//
|
ct_ggsw_lhs_in.encrypt_sk(
|
||||||
// ct_rgsw_out.keyswitch(&module, &ct_rgsw_in, &ct_grlwe, scratch.borrow());
|
&module,
|
||||||
//
|
&pt_ggsw_lhs,
|
||||||
// let mut ct_rlwe_dft: GLWECiphertextFourier<Vec<u8>, FFT64> =
|
&sk_dft,
|
||||||
// GLWECiphertextFourier::new(&module, log_base2k, log_k_rgsw_out, rank);
|
&mut source_xa,
|
||||||
// let mut pt: GLWEPlaintext<Vec<u8>> = GLWEPlaintext::new(&module, log_base2k, log_k_rgsw_out);
|
&mut source_xe,
|
||||||
// let mut pt_dft: VecZnxDft<Vec<u8>, FFT64> = module.new_vec_znx_dft(1, ct_rgsw_out.size());
|
sigma,
|
||||||
// let mut pt_big: VecZnxBig<Vec<u8>, FFT64> = module.new_vec_znx_big(1, ct_rgsw_out.size());
|
scratch.borrow(),
|
||||||
// let mut pt_want: GLWEPlaintext<Vec<u8>> = GLWEPlaintext::new(&module, log_base2k, log_k_rgsw_out);
|
);
|
||||||
//
|
|
||||||
// (0..ct_rgsw_out.rank()).for_each(|col_j| {
|
ct_ggsw_lhs_out.external_product(&module, &ct_ggsw_lhs_in, &ct_ggsw_rhs, scratch.borrow());
|
||||||
// (0..ct_rgsw_out.rows()).for_each(|row_i| {
|
|
||||||
// module.vec_znx_add_scalar_inplace(&mut pt_want, 0, row_i, &pt_rgsw, 0);
|
let mut ct_glwe_fourier: GLWECiphertextFourier<Vec<u8>, FFT64> = GLWECiphertextFourier::new(&module, basek, k_ggsw, rank);
|
||||||
//
|
let mut pt: GLWEPlaintext<Vec<u8>> = GLWEPlaintext::new(&module, basek, k_ggsw);
|
||||||
// if col_j == 1 {
|
let mut pt_dft: VecZnxDft<Vec<u8>, FFT64> = module.new_vec_znx_dft(1, ct_ggsw_lhs_out.size());
|
||||||
// module.vec_znx_dft(&mut pt_dft, 0, &pt_want, 0);
|
let mut pt_big: VecZnxBig<Vec<u8>, FFT64> = module.new_vec_znx_big(1, ct_ggsw_lhs_out.size());
|
||||||
// module.svp_apply_inplace(&mut pt_dft, 0, &sk0_dft, 0);
|
let mut pt_want: GLWEPlaintext<Vec<u8>> = GLWEPlaintext::new(&module, basek, k_ggsw);
|
||||||
// module.vec_znx_idft_tmp_a(&mut pt_big, 0, &mut pt_dft, 0);
|
|
||||||
// module.vec_znx_big_normalize(log_base2k, &mut pt_want, 0, &pt_big, 0, scratch.borrow());
|
module.vec_znx_rotate_inplace(k as i64, &mut pt_ggsw_lhs, 0);
|
||||||
// }
|
|
||||||
//
|
(0..ct_ggsw_lhs_out.rank() + 1).for_each(|col_j| {
|
||||||
// ct_rgsw_out.get_row(&module, row_i, col_j, &mut ct_rlwe_dft);
|
(0..ct_ggsw_lhs_out.rows()).for_each(|row_i| {
|
||||||
// ct_rlwe_dft.decrypt(&module, &mut pt, &sk1_dft, scratch.borrow());
|
module.vec_znx_add_scalar_inplace(&mut pt_want, 0, row_i, &pt_ggsw_lhs, 0);
|
||||||
//
|
|
||||||
// module.vec_znx_sub_ab_inplace(&mut pt, 0, &pt_want, 0);
|
if col_j > 0 {
|
||||||
//
|
module.vec_znx_dft(&mut pt_dft, 0, &pt_want, 0);
|
||||||
// let noise_have: f64 = pt.data.std(0, log_base2k).log2();
|
module.svp_apply_inplace(&mut pt_dft, 0, &sk_dft, col_j - 1);
|
||||||
// let noise_want: f64 = noise_grlwe_rlwe_product(
|
module.vec_znx_idft_tmp_a(&mut pt_big, 0, &mut pt_dft, 0);
|
||||||
// module.n() as f64,
|
module.vec_znx_big_normalize(basek, &mut pt_want, 0, &pt_big, 0, scratch.borrow());
|
||||||
// log_base2k,
|
}
|
||||||
// 0.5,
|
|
||||||
// 0.5,
|
ct_ggsw_lhs_out.get_row(&module, row_i, col_j, &mut ct_glwe_fourier);
|
||||||
// 0f64,
|
ct_glwe_fourier.decrypt(&module, &mut pt, &sk_dft, scratch.borrow());
|
||||||
// sigma * sigma,
|
|
||||||
// 0f64,
|
module.vec_znx_sub_ab_inplace(&mut pt, 0, &pt_want, 0);
|
||||||
// log_k_grlwe,
|
|
||||||
// log_k_grlwe,
|
let noise_have: f64 = pt.data.std(0, basek).log2();
|
||||||
// );
|
|
||||||
//
|
let var_gct_err_lhs: f64 = sigma * sigma;
|
||||||
// assert!(
|
let var_gct_err_rhs: f64 = 0f64;
|
||||||
// (noise_have - noise_want).abs() <= 0.2,
|
|
||||||
// "have: {} want: {}",
|
let var_msg: f64 = 1f64 / module.n() as f64; // X^{k}
|
||||||
// noise_have,
|
let var_a0_err: f64 = sigma * sigma;
|
||||||
// noise_want
|
let var_a1_err: f64 = 1f64 / 12f64;
|
||||||
// );
|
|
||||||
//
|
let noise_want: f64 = noise_ggsw_product(
|
||||||
// pt_want.data.zero();
|
module.n() as f64,
|
||||||
// });
|
basek,
|
||||||
// });
|
0.5,
|
||||||
// }
|
var_msg,
|
||||||
//
|
var_a0_err,
|
||||||
// #[test]
|
var_a1_err,
|
||||||
// fn keyswitch_inplace() {
|
var_gct_err_lhs,
|
||||||
// let module: Module<FFT64> = Module::<FFT64>::new(2048);
|
var_gct_err_rhs,
|
||||||
// let log_base2k: usize = 12;
|
rank as f64,
|
||||||
// let log_k_grlwe: usize = 60;
|
k_ggsw,
|
||||||
// let log_k_rgsw: usize = 45;
|
k_ggsw,
|
||||||
// let rows: usize = (log_k_rgsw + log_base2k - 1) / log_base2k;
|
);
|
||||||
// let rank: usize = 1;
|
|
||||||
//
|
assert!(
|
||||||
// let sigma: f64 = 3.2;
|
(noise_have - noise_want).abs() <= 0.1,
|
||||||
//
|
"have: {} want: {}",
|
||||||
// let mut ct_grlwe: GLWESwitchingKey<Vec<u8>, FFT64> =
|
noise_have,
|
||||||
// GLWESwitchingKey::new(&module, log_base2k, log_k_grlwe, rows, rank, rank);
|
noise_want
|
||||||
// let mut ct_rgsw: GGSWCiphertext<Vec<u8>, FFT64> = GGSWCiphertext::new(&module, log_base2k, log_k_rgsw, rows, rank);
|
);
|
||||||
// let mut pt_rgsw: ScalarZnx<Vec<u8>> = module.new_scalar_znx(1);
|
|
||||||
//
|
pt_want.data.zero();
|
||||||
// let mut source_xs: Source = Source::new([0u8; 32]);
|
});
|
||||||
// let mut source_xe: Source = Source::new([0u8; 32]);
|
});
|
||||||
// let mut source_xa: Source = Source::new([0u8; 32]);
|
}
|
||||||
//
|
|
||||||
// Random input plaintext
|
fn test_external_product_inplace(log_n: usize, basek: usize, k_ggsw: usize, rank: usize, sigma: f64) {
|
||||||
// pt_rgsw.fill_ternary_prob(0, 0.5, &mut source_xs);
|
let module: Module<FFT64> = Module::<FFT64>::new(1 << log_n);
|
||||||
//
|
let rows: usize = (k_ggsw + basek - 1) / basek;
|
||||||
// let mut scratch: ScratchOwned = ScratchOwned::new(
|
|
||||||
// GLWESwitchingKey::encrypt_sk_scratch_space(&module, rank, ct_grlwe.size())
|
let mut ct_ggsw_rhs: GGSWCiphertext<Vec<u8>, FFT64> = GGSWCiphertext::new(&module, basek, k_ggsw, rows, rank);
|
||||||
// | GLWECiphertextFourier::decrypt_scratch_space(&module, ct_rgsw.size())
|
let mut ct_ggsw_lhs: GGSWCiphertext<Vec<u8>, FFT64> = GGSWCiphertext::new(&module, basek, k_ggsw, rows, rank);
|
||||||
// | GGSWCiphertext::encrypt_sk_scratch_space(&module, rank, ct_rgsw.size())
|
let mut pt_ggsw_lhs: ScalarZnx<Vec<u8>> = module.new_scalar_znx(1);
|
||||||
// | GGSWCiphertext::keyswitch_inplace_scratch_space(&module, ct_rgsw.size(), ct_grlwe.size()),
|
let mut pt_ggsw_rhs: ScalarZnx<Vec<u8>> = module.new_scalar_znx(1);
|
||||||
// );
|
|
||||||
//
|
let mut source_xs: Source = Source::new([0u8; 32]);
|
||||||
// let mut sk0: SecretKey<Vec<u8>> = SecretKey::new(&module, rank);
|
let mut source_xe: Source = Source::new([0u8; 32]);
|
||||||
// sk0.fill_ternary_prob(0.5, &mut source_xs);
|
let mut source_xa: Source = Source::new([0u8; 32]);
|
||||||
//
|
|
||||||
// let mut sk0_dft: SecretKeyFourier<Vec<u8>, FFT64> = SecretKeyFourier::new(&module, rank);
|
pt_ggsw_lhs.fill_ternary_prob(0, 0.5, &mut source_xs);
|
||||||
// sk0_dft.dft(&module, &sk0);
|
|
||||||
//
|
let k: usize = 1;
|
||||||
// let mut sk1: SecretKey<Vec<u8>> = SecretKey::new(&module, rank);
|
|
||||||
// sk1.fill_ternary_prob(0.5, &mut source_xs);
|
pt_ggsw_rhs.to_mut().raw_mut()[k] = 1; //X^{k}
|
||||||
//
|
|
||||||
// let mut sk1_dft: SecretKeyFourier<Vec<u8>, FFT64> = SecretKeyFourier::new(&module, rank);
|
let mut scratch: ScratchOwned = ScratchOwned::new(
|
||||||
// sk1_dft.dft(&module, &sk1);
|
GLWESwitchingKey::encrypt_sk_scratch_space(&module, rank, ct_ggsw_rhs.size())
|
||||||
//
|
| GLWECiphertextFourier::decrypt_scratch_space(&module, ct_ggsw_lhs.size())
|
||||||
// ct_grlwe.encrypt_sk(
|
| GGSWCiphertext::encrypt_sk_scratch_space(&module, rank, ct_ggsw_lhs.size())
|
||||||
// &module,
|
| GGSWCiphertext::external_product_inplace_scratch_space(&module, ct_ggsw_lhs.size(), ct_ggsw_rhs.size(), rank),
|
||||||
// &sk0.data,
|
);
|
||||||
// &sk1_dft,
|
|
||||||
// &mut source_xa,
|
let mut sk: SecretKey<Vec<u8>> = SecretKey::new(&module, rank);
|
||||||
// &mut source_xe,
|
sk.fill_ternary_prob(0.5, &mut source_xs);
|
||||||
// sigma,
|
|
||||||
// scratch.borrow(),
|
let mut sk_dft: SecretKeyFourier<Vec<u8>, FFT64> = SecretKeyFourier::new(&module, rank);
|
||||||
// );
|
sk_dft.dft(&module, &sk);
|
||||||
//
|
|
||||||
// ct_rgsw.encrypt_sk(
|
ct_ggsw_rhs.encrypt_sk(
|
||||||
// &module,
|
&module,
|
||||||
// &pt_rgsw,
|
&pt_ggsw_rhs,
|
||||||
// &sk0_dft,
|
&sk_dft,
|
||||||
// &mut source_xa,
|
&mut source_xa,
|
||||||
// &mut source_xe,
|
&mut source_xe,
|
||||||
// sigma,
|
sigma,
|
||||||
// scratch.borrow(),
|
scratch.borrow(),
|
||||||
// );
|
);
|
||||||
//
|
|
||||||
// ct_rgsw.keyswitch_inplace(&module, &ct_grlwe, scratch.borrow());
|
ct_ggsw_lhs.encrypt_sk(
|
||||||
//
|
&module,
|
||||||
// let mut ct_rlwe_dft: GLWECiphertextFourier<Vec<u8>, FFT64> =
|
&pt_ggsw_lhs,
|
||||||
// GLWECiphertextFourier::new(&module, log_base2k, log_k_rgsw, rank);
|
&sk_dft,
|
||||||
// let mut pt: GLWEPlaintext<Vec<u8>> = GLWEPlaintext::new(&module, log_base2k, log_k_rgsw);
|
&mut source_xa,
|
||||||
// let mut pt_dft: VecZnxDft<Vec<u8>, FFT64> = module.new_vec_znx_dft(1, ct_rgsw.size());
|
&mut source_xe,
|
||||||
// let mut pt_big: VecZnxBig<Vec<u8>, FFT64> = module.new_vec_znx_big(1, ct_rgsw.size());
|
sigma,
|
||||||
// let mut pt_want: GLWEPlaintext<Vec<u8>> = GLWEPlaintext::new(&module, log_base2k, log_k_rgsw);
|
scratch.borrow(),
|
||||||
//
|
);
|
||||||
// (0..ct_rgsw.rank()).for_each(|col_j| {
|
|
||||||
// (0..ct_rgsw.rows()).for_each(|row_i| {
|
ct_ggsw_lhs.external_product_inplace(&module, &ct_ggsw_rhs, scratch.borrow());
|
||||||
// module.vec_znx_add_scalar_inplace(&mut pt_want, 0, row_i, &pt_rgsw, 0);
|
|
||||||
//
|
let mut ct_glwe_fourier: GLWECiphertextFourier<Vec<u8>, FFT64> = GLWECiphertextFourier::new(&module, basek, k_ggsw, rank);
|
||||||
// if col_j == 1 {
|
let mut pt: GLWEPlaintext<Vec<u8>> = GLWEPlaintext::new(&module, basek, k_ggsw);
|
||||||
// module.vec_znx_dft(&mut pt_dft, 0, &pt_want, 0);
|
let mut pt_dft: VecZnxDft<Vec<u8>, FFT64> = module.new_vec_znx_dft(1, ct_ggsw_lhs.size());
|
||||||
// module.svp_apply_inplace(&mut pt_dft, 0, &sk0_dft, 0);
|
let mut pt_big: VecZnxBig<Vec<u8>, FFT64> = module.new_vec_znx_big(1, ct_ggsw_lhs.size());
|
||||||
// module.vec_znx_idft_tmp_a(&mut pt_big, 0, &mut pt_dft, 0);
|
let mut pt_want: GLWEPlaintext<Vec<u8>> = GLWEPlaintext::new(&module, basek, k_ggsw);
|
||||||
// module.vec_znx_big_normalize(log_base2k, &mut pt_want, 0, &pt_big, 0, scratch.borrow());
|
|
||||||
// }
|
module.vec_znx_rotate_inplace(k as i64, &mut pt_ggsw_lhs, 0);
|
||||||
//
|
|
||||||
// ct_rgsw.get_row(&module, row_i, col_j, &mut ct_rlwe_dft);
|
(0..ct_ggsw_lhs.rank() + 1).for_each(|col_j| {
|
||||||
// ct_rlwe_dft.decrypt(&module, &mut pt, &sk1_dft, scratch.borrow());
|
(0..ct_ggsw_lhs.rows()).for_each(|row_i| {
|
||||||
//
|
module.vec_znx_add_scalar_inplace(&mut pt_want, 0, row_i, &pt_ggsw_lhs, 0);
|
||||||
// module.vec_znx_sub_ab_inplace(&mut pt, 0, &pt_want, 0);
|
|
||||||
//
|
if col_j > 0 {
|
||||||
// let noise_have: f64 = pt.data.std(0, log_base2k).log2();
|
module.vec_znx_dft(&mut pt_dft, 0, &pt_want, 0);
|
||||||
// let noise_want: f64 = noise_grlwe_rlwe_product(
|
module.svp_apply_inplace(&mut pt_dft, 0, &sk_dft, col_j - 1);
|
||||||
// module.n() as f64,
|
module.vec_znx_idft_tmp_a(&mut pt_big, 0, &mut pt_dft, 0);
|
||||||
// log_base2k,
|
module.vec_znx_big_normalize(basek, &mut pt_want, 0, &pt_big, 0, scratch.borrow());
|
||||||
// 0.5,
|
}
|
||||||
// 0.5,
|
|
||||||
// 0f64,
|
ct_ggsw_lhs.get_row(&module, row_i, col_j, &mut ct_glwe_fourier);
|
||||||
// sigma * sigma,
|
ct_glwe_fourier.decrypt(&module, &mut pt, &sk_dft, scratch.borrow());
|
||||||
// 0f64,
|
|
||||||
// log_k_grlwe,
|
module.vec_znx_sub_ab_inplace(&mut pt, 0, &pt_want, 0);
|
||||||
// log_k_grlwe,
|
|
||||||
// );
|
let noise_have: f64 = pt.data.std(0, basek).log2();
|
||||||
//
|
|
||||||
// assert!(
|
let var_gct_err_lhs: f64 = sigma * sigma;
|
||||||
// (noise_have - noise_want).abs() <= 0.2,
|
let var_gct_err_rhs: f64 = 0f64;
|
||||||
// "have: {} want: {}",
|
|
||||||
// noise_have,
|
let var_msg: f64 = 1f64 / module.n() as f64; // X^{k}
|
||||||
// noise_want
|
let var_a0_err: f64 = sigma * sigma;
|
||||||
// );
|
let var_a1_err: f64 = 1f64 / 12f64;
|
||||||
//
|
|
||||||
// pt_want.data.zero();
|
let noise_want: f64 = noise_ggsw_product(
|
||||||
// });
|
module.n() as f64,
|
||||||
// });
|
basek,
|
||||||
// }
|
0.5,
|
||||||
//
|
var_msg,
|
||||||
// #[test]
|
var_a0_err,
|
||||||
// fn external_product() {
|
var_a1_err,
|
||||||
// let module: Module<FFT64> = Module::<FFT64>::new(2048);
|
var_gct_err_lhs,
|
||||||
// let log_base2k: usize = 12;
|
var_gct_err_rhs,
|
||||||
// let log_k_rgsw_rhs: usize = 60;
|
rank as f64,
|
||||||
// let log_k_rgsw_lhs_in: usize = 45;
|
k_ggsw,
|
||||||
// let log_k_rgsw_lhs_out: usize = 45;
|
k_ggsw,
|
||||||
// let rows: usize = (log_k_rgsw_lhs_in + log_base2k - 1) / log_base2k;
|
);
|
||||||
// let rank: usize = 1;
|
|
||||||
//
|
assert!(
|
||||||
// let sigma: f64 = 3.2;
|
(noise_have - noise_want).abs() <= 0.1,
|
||||||
//
|
"have: {} want: {}",
|
||||||
// let mut ct_rgsw_rhs: GGSWCiphertext<Vec<u8>, FFT64> = GGSWCiphertext::new(&module, log_base2k, log_k_rgsw_rhs, rows, rank);
|
noise_have,
|
||||||
// let mut ct_rgsw_lhs_in: GGSWCiphertext<Vec<u8>, FFT64> =
|
noise_want
|
||||||
// GGSWCiphertext::new(&module, log_base2k, log_k_rgsw_lhs_in, rows, rank);
|
);
|
||||||
// let mut ct_rgsw_lhs_out: GGSWCiphertext<Vec<u8>, FFT64> =
|
|
||||||
// GGSWCiphertext::new(&module, log_base2k, log_k_rgsw_lhs_out, rows, rank);
|
pt_want.data.zero();
|
||||||
// let mut pt_rgsw_lhs: ScalarZnx<Vec<u8>> = module.new_scalar_znx(1);
|
});
|
||||||
// let mut pt_rgsw_rhs: ScalarZnx<Vec<u8>> = module.new_scalar_znx(1);
|
});
|
||||||
//
|
}
|
||||||
// let mut source_xs: Source = Source::new([0u8; 32]);
|
pub(crate) fn noise_ggsw_product(
|
||||||
// let mut source_xe: Source = Source::new([0u8; 32]);
|
|
||||||
// let mut source_xa: Source = Source::new([0u8; 32]);
|
|
||||||
//
|
|
||||||
// Random input plaintext
|
|
||||||
// pt_rgsw_lhs.fill_ternary_prob(0, 0.5, &mut source_xs);
|
|
||||||
//
|
|
||||||
// let k: usize = 1;
|
|
||||||
//
|
|
||||||
// pt_rgsw_rhs.to_mut().raw_mut()[k] = 1; //X^{k}
|
|
||||||
//
|
|
||||||
// let mut scratch: ScratchOwned = ScratchOwned::new(
|
|
||||||
// GLWESwitchingKey::encrypt_sk_scratch_space(&module, rank, ct_rgsw_rhs.size())
|
|
||||||
// | GLWECiphertextFourier::decrypt_scratch_space(&module, ct_rgsw_lhs_out.size())
|
|
||||||
// | GGSWCiphertext::encrypt_sk_scratch_space(&module, rank, ct_rgsw_lhs_in.size())
|
|
||||||
// | GGSWCiphertext::external_product_scratch_space(
|
|
||||||
// &module,
|
|
||||||
// ct_rgsw_lhs_out.size(),
|
|
||||||
// ct_rgsw_lhs_in.size(),
|
|
||||||
// ct_rgsw_rhs.size(),
|
|
||||||
// ),
|
|
||||||
// );
|
|
||||||
//
|
|
||||||
// let mut sk: SecretKey<Vec<u8>> = SecretKey::new(&module, rank);
|
|
||||||
// sk.fill_ternary_prob(0.5, &mut source_xs);
|
|
||||||
//
|
|
||||||
// let mut sk_dft: SecretKeyFourier<Vec<u8>, FFT64> = SecretKeyFourier::new(&module, rank);
|
|
||||||
// sk_dft.dft(&module, &sk);
|
|
||||||
//
|
|
||||||
// ct_rgsw_rhs.encrypt_sk(
|
|
||||||
// &module,
|
|
||||||
// &pt_rgsw_rhs,
|
|
||||||
// &sk_dft,
|
|
||||||
// &mut source_xa,
|
|
||||||
// &mut source_xe,
|
|
||||||
// sigma,
|
|
||||||
// scratch.borrow(),
|
|
||||||
// );
|
|
||||||
//
|
|
||||||
// ct_rgsw_lhs_in.encrypt_sk(
|
|
||||||
// &module,
|
|
||||||
// &pt_rgsw_lhs,
|
|
||||||
// &sk_dft,
|
|
||||||
// &mut source_xa,
|
|
||||||
// &mut source_xe,
|
|
||||||
// sigma,
|
|
||||||
// scratch.borrow(),
|
|
||||||
// );
|
|
||||||
//
|
|
||||||
// ct_rgsw_lhs_out.external_product(&module, &ct_rgsw_lhs_in, &ct_rgsw_rhs, scratch.borrow());
|
|
||||||
//
|
|
||||||
// let mut ct_rlwe_dft: GLWECiphertextFourier<Vec<u8>, FFT64> =
|
|
||||||
// GLWECiphertextFourier::new(&module, log_base2k, log_k_rgsw_lhs_out, rank);
|
|
||||||
// let mut pt: GLWEPlaintext<Vec<u8>> = GLWEPlaintext::new(&module, log_base2k, log_k_rgsw_lhs_out);
|
|
||||||
// let mut pt_dft: VecZnxDft<Vec<u8>, FFT64> = module.new_vec_znx_dft(1, ct_rgsw_lhs_out.size());
|
|
||||||
// let mut pt_big: VecZnxBig<Vec<u8>, FFT64> = module.new_vec_znx_big(1, ct_rgsw_lhs_out.size());
|
|
||||||
// let mut pt_want: GLWEPlaintext<Vec<u8>> = GLWEPlaintext::new(&module, log_base2k, log_k_rgsw_lhs_out);
|
|
||||||
//
|
|
||||||
// module.vec_znx_rotate_inplace(k as i64, &mut pt_rgsw_lhs, 0);
|
|
||||||
//
|
|
||||||
// (0..ct_rgsw_lhs_out.rank()).for_each(|col_j| {
|
|
||||||
// (0..ct_rgsw_lhs_out.rows()).for_each(|row_i| {
|
|
||||||
// module.vec_znx_add_scalar_inplace(&mut pt_want, 0, row_i, &pt_rgsw_lhs, 0);
|
|
||||||
//
|
|
||||||
// if col_j == 1 {
|
|
||||||
// module.vec_znx_dft(&mut pt_dft, 0, &pt_want, 0);
|
|
||||||
// module.svp_apply_inplace(&mut pt_dft, 0, &sk_dft, 0);
|
|
||||||
// module.vec_znx_idft_tmp_a(&mut pt_big, 0, &mut pt_dft, 0);
|
|
||||||
// module.vec_znx_big_normalize(log_base2k, &mut pt_want, 0, &pt_big, 0, scratch.borrow());
|
|
||||||
// }
|
|
||||||
//
|
|
||||||
// ct_rgsw_lhs_out.get_row(&module, row_i, col_j, &mut ct_rlwe_dft);
|
|
||||||
// ct_rlwe_dft.decrypt(&module, &mut pt, &sk_dft, scratch.borrow());
|
|
||||||
//
|
|
||||||
// module.vec_znx_sub_ab_inplace(&mut pt, 0, &pt_want, 0);
|
|
||||||
//
|
|
||||||
// let noise_have: f64 = pt.data.std(0, log_base2k).log2();
|
|
||||||
//
|
|
||||||
// let var_gct_err_lhs: f64 = sigma * sigma;
|
|
||||||
// let var_gct_err_rhs: f64 = 0f64;
|
|
||||||
//
|
|
||||||
// let var_msg: f64 = 1f64 / module.n() as f64; // X^{k}
|
|
||||||
// let var_a0_err: f64 = sigma * sigma;
|
|
||||||
// let var_a1_err: f64 = 1f64 / 12f64;
|
|
||||||
//
|
|
||||||
// let noise_want: f64 = noise_rgsw_product(
|
|
||||||
// module.n() as f64,
|
|
||||||
// log_base2k,
|
|
||||||
// 0.5,
|
|
||||||
// var_msg,
|
|
||||||
// var_a0_err,
|
|
||||||
// var_a1_err,
|
|
||||||
// var_gct_err_lhs,
|
|
||||||
// var_gct_err_rhs,
|
|
||||||
// log_k_rgsw_lhs_in,
|
|
||||||
// log_k_rgsw_rhs,
|
|
||||||
// );
|
|
||||||
//
|
|
||||||
// assert!(
|
|
||||||
// (noise_have - noise_want).abs() <= 0.1,
|
|
||||||
// "have: {} want: {}",
|
|
||||||
// noise_have,
|
|
||||||
// noise_want
|
|
||||||
// );
|
|
||||||
//
|
|
||||||
// pt_want.data.zero();
|
|
||||||
// });
|
|
||||||
// });
|
|
||||||
// }
|
|
||||||
//
|
|
||||||
// #[test]
|
|
||||||
// fn external_product_inplace() {
|
|
||||||
// let module: Module<FFT64> = Module::<FFT64>::new(2048);
|
|
||||||
// let log_base2k: usize = 12;
|
|
||||||
// let log_k_rgsw_rhs: usize = 60;
|
|
||||||
// let log_k_rgsw_lhs: usize = 45;
|
|
||||||
// let rows: usize = (log_k_rgsw_lhs + log_base2k - 1) / log_base2k;
|
|
||||||
// let rank: usize = 1;
|
|
||||||
//
|
|
||||||
// let sigma: f64 = 3.2;
|
|
||||||
//
|
|
||||||
// let mut ct_rgsw_rhs: GGSWCiphertext<Vec<u8>, FFT64> = GGSWCiphertext::new(&module, log_base2k, log_k_rgsw_rhs, rows, rank);
|
|
||||||
// let mut ct_rgsw_lhs: GGSWCiphertext<Vec<u8>, FFT64> = GGSWCiphertext::new(&module, log_base2k, log_k_rgsw_lhs, rows, rank);
|
|
||||||
// let mut pt_rgsw_lhs: ScalarZnx<Vec<u8>> = module.new_scalar_znx(1);
|
|
||||||
// let mut pt_rgsw_rhs: ScalarZnx<Vec<u8>> = module.new_scalar_znx(1);
|
|
||||||
//
|
|
||||||
// let mut source_xs: Source = Source::new([0u8; 32]);
|
|
||||||
// let mut source_xe: Source = Source::new([0u8; 32]);
|
|
||||||
// let mut source_xa: Source = Source::new([0u8; 32]);
|
|
||||||
//
|
|
||||||
// Random input plaintext
|
|
||||||
// pt_rgsw_lhs.fill_ternary_prob(0, 0.5, &mut source_xs);
|
|
||||||
//
|
|
||||||
// let k: usize = 1;
|
|
||||||
//
|
|
||||||
// pt_rgsw_rhs.to_mut().raw_mut()[k] = 1; //X^{k}
|
|
||||||
//
|
|
||||||
// let mut scratch: ScratchOwned = ScratchOwned::new(
|
|
||||||
// GLWESwitchingKey::encrypt_sk_scratch_space(&module, rank, ct_rgsw_rhs.size())
|
|
||||||
// | GLWECiphertextFourier::decrypt_scratch_space(&module, ct_rgsw_lhs.size())
|
|
||||||
// | GGSWCiphertext::encrypt_sk_scratch_space(&module, rank, ct_rgsw_lhs.size())
|
|
||||||
// | GGSWCiphertext::external_product_inplace_scratch_space(&module, ct_rgsw_lhs.size(), ct_rgsw_rhs.size()),
|
|
||||||
// );
|
|
||||||
//
|
|
||||||
// let mut sk: SecretKey<Vec<u8>> = SecretKey::new(&module, rank);
|
|
||||||
// sk.fill_ternary_prob(0.5, &mut source_xs);
|
|
||||||
//
|
|
||||||
// let mut sk_dft: SecretKeyFourier<Vec<u8>, FFT64> = SecretKeyFourier::new(&module, rank);
|
|
||||||
// sk_dft.dft(&module, &sk);
|
|
||||||
//
|
|
||||||
// ct_rgsw_rhs.encrypt_sk(
|
|
||||||
// &module,
|
|
||||||
// &pt_rgsw_rhs,
|
|
||||||
// &sk_dft,
|
|
||||||
// &mut source_xa,
|
|
||||||
// &mut source_xe,
|
|
||||||
// sigma,
|
|
||||||
// scratch.borrow(),
|
|
||||||
// );
|
|
||||||
//
|
|
||||||
// ct_rgsw_lhs.encrypt_sk(
|
|
||||||
// &module,
|
|
||||||
// &pt_rgsw_lhs,
|
|
||||||
// &sk_dft,
|
|
||||||
// &mut source_xa,
|
|
||||||
// &mut source_xe,
|
|
||||||
// sigma,
|
|
||||||
// scratch.borrow(),
|
|
||||||
// );
|
|
||||||
//
|
|
||||||
// ct_rgsw_lhs.external_product_inplace(&module, &ct_rgsw_rhs, scratch.borrow());
|
|
||||||
//
|
|
||||||
// let mut ct_rlwe_dft: GLWECiphertextFourier<Vec<u8>, FFT64> =
|
|
||||||
// GLWECiphertextFourier::new(&module, log_base2k, log_k_rgsw_lhs, rank);
|
|
||||||
// let mut pt: GLWEPlaintext<Vec<u8>> = GLWEPlaintext::new(&module, log_base2k, log_k_rgsw_lhs);
|
|
||||||
// let mut pt_dft: VecZnxDft<Vec<u8>, FFT64> = module.new_vec_znx_dft(1, ct_rgsw_lhs.size());
|
|
||||||
// let mut pt_big: VecZnxBig<Vec<u8>, FFT64> = module.new_vec_znx_big(1, ct_rgsw_lhs.size());
|
|
||||||
// let mut pt_want: GLWEPlaintext<Vec<u8>> = GLWEPlaintext::new(&module, log_base2k, log_k_rgsw_lhs);
|
|
||||||
//
|
|
||||||
// module.vec_znx_rotate_inplace(k as i64, &mut pt_rgsw_lhs, 0);
|
|
||||||
//
|
|
||||||
// (0..ct_rgsw_lhs.rank()).for_each(|col_j| {
|
|
||||||
// (0..ct_rgsw_lhs.rows()).for_each(|row_i| {
|
|
||||||
// module.vec_znx_add_scalar_inplace(&mut pt_want, 0, row_i, &pt_rgsw_lhs, 0);
|
|
||||||
//
|
|
||||||
// if col_j == 1 {
|
|
||||||
// module.vec_znx_dft(&mut pt_dft, 0, &pt_want, 0);
|
|
||||||
// module.svp_apply_inplace(&mut pt_dft, 0, &sk_dft, 0);
|
|
||||||
// module.vec_znx_idft_tmp_a(&mut pt_big, 0, &mut pt_dft, 0);
|
|
||||||
// module.vec_znx_big_normalize(log_base2k, &mut pt_want, 0, &pt_big, 0, scratch.borrow());
|
|
||||||
// }
|
|
||||||
//
|
|
||||||
// ct_rgsw_lhs.get_row(&module, row_i, col_j, &mut ct_rlwe_dft);
|
|
||||||
// ct_rlwe_dft.decrypt(&module, &mut pt, &sk_dft, scratch.borrow());
|
|
||||||
//
|
|
||||||
// module.vec_znx_sub_ab_inplace(&mut pt, 0, &pt_want, 0);
|
|
||||||
//
|
|
||||||
// let noise_have: f64 = pt.data.std(0, log_base2k).log2();
|
|
||||||
//
|
|
||||||
// let var_gct_err_lhs: f64 = sigma * sigma;
|
|
||||||
// let var_gct_err_rhs: f64 = 0f64;
|
|
||||||
//
|
|
||||||
// let var_msg: f64 = 1f64 / module.n() as f64; // X^{k}
|
|
||||||
// let var_a0_err: f64 = sigma * sigma;
|
|
||||||
// let var_a1_err: f64 = 1f64 / 12f64;
|
|
||||||
//
|
|
||||||
// let noise_want: f64 = noise_rgsw_product(
|
|
||||||
// module.n() as f64,
|
|
||||||
// log_base2k,
|
|
||||||
// 0.5,
|
|
||||||
// var_msg,
|
|
||||||
// var_a0_err,
|
|
||||||
// var_a1_err,
|
|
||||||
// var_gct_err_lhs,
|
|
||||||
// var_gct_err_rhs,
|
|
||||||
// log_k_rgsw_lhs,
|
|
||||||
// log_k_rgsw_rhs,
|
|
||||||
// );
|
|
||||||
//
|
|
||||||
// assert!(
|
|
||||||
// (noise_have - noise_want).abs() <= 0.1,
|
|
||||||
// "have: {} want: {}",
|
|
||||||
// noise_have,
|
|
||||||
// noise_want
|
|
||||||
// );
|
|
||||||
//
|
|
||||||
// pt_want.data.zero();
|
|
||||||
// });
|
|
||||||
// });
|
|
||||||
// }
|
|
||||||
pub(crate) fn noise_ggsw_gglwe_product(
|
|
||||||
n: f64,
|
n: f64,
|
||||||
log_base2k: usize,
|
basek: usize,
|
||||||
var_xs: f64,
|
var_xs: f64,
|
||||||
var_msg: f64,
|
var_msg: f64,
|
||||||
var_a0_err: f64,
|
var_a0_err: f64,
|
||||||
@@ -581,12 +351,12 @@ pub(crate) fn noise_ggsw_gglwe_product(
|
|||||||
b_logq: usize,
|
b_logq: usize,
|
||||||
) -> f64 {
|
) -> f64 {
|
||||||
let a_logq: usize = a_logq.min(b_logq);
|
let a_logq: usize = a_logq.min(b_logq);
|
||||||
let a_cols: usize = (a_logq + log_base2k - 1) / log_base2k;
|
let a_cols: usize = (a_logq + basek - 1) / basek;
|
||||||
|
|
||||||
let b_scale = 2.0f64.powi(b_logq as i32);
|
let b_scale = 2.0f64.powi(b_logq as i32);
|
||||||
let a_scale: f64 = 2.0f64.powi((b_logq - a_logq) as i32);
|
let a_scale: f64 = 2.0f64.powi((b_logq - a_logq) as i32);
|
||||||
|
|
||||||
let base: f64 = (1 << (log_base2k)) as f64;
|
let base: f64 = (1 << (basek)) as f64;
|
||||||
let var_base: f64 = base * base / 12f64;
|
let var_base: f64 = base * base / 12f64;
|
||||||
|
|
||||||
// lhs = a_cols * n * (var_base * var_gct_err_lhs + var_e_a * var_msg * p^2)
|
// lhs = a_cols * n * (var_base * var_gct_err_lhs + var_e_a * var_msg * p^2)
|
||||||
|
|||||||
@@ -13,7 +13,7 @@ use crate::{
|
|||||||
glwe_plaintext::GLWEPlaintext,
|
glwe_plaintext::GLWEPlaintext,
|
||||||
keys::{GLWEPublicKey, SecretKey, SecretKeyFourier},
|
keys::{GLWEPublicKey, SecretKey, SecretKeyFourier},
|
||||||
keyswitch_key::GLWESwitchingKey,
|
keyswitch_key::GLWESwitchingKey,
|
||||||
test_fft64::{gglwe::noise_gglwe_product, ggsw::noise_ggsw_gglwe_product},
|
test_fft64::{gglwe::noise_gglwe_product, ggsw::noise_ggsw_product},
|
||||||
};
|
};
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
@@ -498,7 +498,7 @@ fn test_external_product(log_n: usize, basek: usize, k_ggsw: usize, k_ct_in: usi
|
|||||||
let var_a0_err: f64 = sigma * sigma;
|
let var_a0_err: f64 = sigma * sigma;
|
||||||
let var_a1_err: f64 = 1f64 / 12f64;
|
let var_a1_err: f64 = 1f64 / 12f64;
|
||||||
|
|
||||||
let noise_want: f64 = noise_ggsw_gglwe_product(
|
let noise_want: f64 = noise_ggsw_product(
|
||||||
module.n() as f64,
|
module.n() as f64,
|
||||||
basek,
|
basek,
|
||||||
0.5,
|
0.5,
|
||||||
@@ -595,7 +595,7 @@ fn test_external_product_inplace(log_n: usize, basek: usize, k_ggsw: usize, k_ct
|
|||||||
let var_a0_err: f64 = sigma * sigma;
|
let var_a0_err: f64 = sigma * sigma;
|
||||||
let var_a1_err: f64 = 1f64 / 12f64;
|
let var_a1_err: f64 = 1f64 / 12f64;
|
||||||
|
|
||||||
let noise_want: f64 = noise_ggsw_gglwe_product(
|
let noise_want: f64 = noise_ggsw_product(
|
||||||
module.n() as f64,
|
module.n() as f64,
|
||||||
basek,
|
basek,
|
||||||
0.5,
|
0.5,
|
||||||
|
|||||||
@@ -6,7 +6,7 @@ use crate::{
|
|||||||
glwe_plaintext::GLWEPlaintext,
|
glwe_plaintext::GLWEPlaintext,
|
||||||
keys::{SecretKey, SecretKeyFourier},
|
keys::{SecretKey, SecretKeyFourier},
|
||||||
keyswitch_key::GLWESwitchingKey,
|
keyswitch_key::GLWESwitchingKey,
|
||||||
test_fft64::{gglwe::noise_gglwe_product, ggsw::noise_ggsw_gglwe_product},
|
test_fft64::{gglwe::noise_gglwe_product, ggsw::noise_ggsw_product},
|
||||||
};
|
};
|
||||||
use base2k::{FFT64, FillUniform, Module, ScalarZnx, ScalarZnxAlloc, ScratchOwned, Stats, VecZnxOps, VecZnxToMut, ZnxViewMut};
|
use base2k::{FFT64, FillUniform, Module, ScalarZnx, ScalarZnxAlloc, ScratchOwned, Stats, VecZnxOps, VecZnxToMut, ZnxViewMut};
|
||||||
use sampling::source::Source;
|
use sampling::source::Source;
|
||||||
@@ -322,7 +322,7 @@ fn test_external_product(log_n: usize, basek: usize, k_ggsw: usize, k_ct_in: usi
|
|||||||
let var_a0_err: f64 = sigma * sigma;
|
let var_a0_err: f64 = sigma * sigma;
|
||||||
let var_a1_err: f64 = 1f64 / 12f64;
|
let var_a1_err: f64 = 1f64 / 12f64;
|
||||||
|
|
||||||
let noise_want: f64 = noise_ggsw_gglwe_product(
|
let noise_want: f64 = noise_ggsw_product(
|
||||||
module.n() as f64,
|
module.n() as f64,
|
||||||
basek,
|
basek,
|
||||||
0.5,
|
0.5,
|
||||||
@@ -422,7 +422,7 @@ fn test_external_product_inplace(log_n: usize, basek: usize, k_ggsw: usize, k_ct
|
|||||||
let var_a0_err: f64 = sigma * sigma;
|
let var_a0_err: f64 = sigma * sigma;
|
||||||
let var_a1_err: f64 = 1f64 / 12f64;
|
let var_a1_err: f64 = 1f64 / 12f64;
|
||||||
|
|
||||||
let noise_want: f64 = noise_ggsw_gglwe_product(
|
let noise_want: f64 = noise_ggsw_product(
|
||||||
module.n() as f64,
|
module.n() as f64,
|
||||||
basek,
|
basek,
|
||||||
0.5,
|
0.5,
|
||||||
|
|||||||
@@ -1,3 +1,3 @@
|
|||||||
pub(crate) fn derive_size(log_base2k: usize, log_k: usize) -> usize {
|
pub(crate) fn derive_size(basek: usize, k: usize) -> usize {
|
||||||
(log_k + log_base2k - 1) / log_base2k
|
(k + basek - 1) / basek
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,218 +0,0 @@
|
|||||||
use base2k::{
|
|
||||||
FFT64, Module, Scratch, VecZnx, VecZnxAlloc, VecZnxDft, VecZnxDftAlloc, VecZnxDftOps, VecZnxDftToMut, VecZnxDftToRef,
|
|
||||||
VecZnxToMut, VecZnxToRef, ZnxInfos, ZnxZero,
|
|
||||||
};
|
|
||||||
|
|
||||||
use crate::{
|
|
||||||
elem::{GetRow, Infos, SetRow},
|
|
||||||
glwe_ciphertext::GLWECiphertext,
|
|
||||||
glwe_ciphertext_fourier::GLWECiphertextFourier,
|
|
||||||
};
|
|
||||||
|
|
||||||
pub(crate) trait VecGLWEProductScratchSpace {
|
|
||||||
fn prod_with_glwe_scratch_space(
|
|
||||||
module: &Module<FFT64>,
|
|
||||||
res_size: usize,
|
|
||||||
lhs: usize,
|
|
||||||
rhs: usize,
|
|
||||||
rank_in: usize,
|
|
||||||
rank_out: usize,
|
|
||||||
) -> usize;
|
|
||||||
|
|
||||||
fn prod_with_glwe_inplace_scratch_space(module: &Module<FFT64>, res_size: usize, rhs: usize, rank: usize) -> usize {
|
|
||||||
Self::prod_with_glwe_scratch_space(module, res_size, res_size, rhs, rank, rank)
|
|
||||||
}
|
|
||||||
|
|
||||||
fn prod_with_glwe_fourier_scratch_space(
|
|
||||||
module: &Module<FFT64>,
|
|
||||||
res_size: usize,
|
|
||||||
lhs: usize,
|
|
||||||
rhs: usize,
|
|
||||||
rank_in: usize,
|
|
||||||
rank_out: usize,
|
|
||||||
) -> usize {
|
|
||||||
(Self::prod_with_glwe_scratch_space(module, res_size, lhs, rhs, rank_in, rank_out) | module.vec_znx_idft_tmp_bytes())
|
|
||||||
+ module.bytes_of_vec_znx(rank_in + 1, lhs)
|
|
||||||
+ module.bytes_of_vec_znx(rank_out + 1, res_size)
|
|
||||||
}
|
|
||||||
|
|
||||||
fn prod_with_glwe_fourier_inplace_scratch_space(module: &Module<FFT64>, res_size: usize, rhs: usize, rank: usize) -> usize {
|
|
||||||
(Self::prod_with_glwe_inplace_scratch_space(module, res_size, rhs, rank) | module.vec_znx_idft_tmp_bytes())
|
|
||||||
+ module.bytes_of_vec_znx(rank + 1, res_size)
|
|
||||||
}
|
|
||||||
|
|
||||||
fn prod_with_vec_glwe_scratch_space(
|
|
||||||
module: &Module<FFT64>,
|
|
||||||
res_size: usize,
|
|
||||||
lhs: usize,
|
|
||||||
rhs: usize,
|
|
||||||
rank_in: usize,
|
|
||||||
rank_out: usize,
|
|
||||||
) -> usize {
|
|
||||||
Self::prod_with_glwe_fourier_scratch_space(module, res_size, lhs, rhs, rank_in, rank_out)
|
|
||||||
+ module.bytes_of_vec_znx_dft(rank_in + 1, lhs)
|
|
||||||
+ module.bytes_of_vec_znx_dft(rank_out + 1, res_size)
|
|
||||||
}
|
|
||||||
|
|
||||||
fn prod_with_vec_glwe_inplace_scratch_space(module: &Module<FFT64>, res_size: usize, rhs: usize, rank: usize) -> usize {
|
|
||||||
Self::prod_with_glwe_fourier_inplace_scratch_space(module, res_size, rhs, rank)
|
|
||||||
+ module.bytes_of_vec_znx_dft(rank + 1, res_size)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
pub(crate) trait VecGLWEProduct: Infos {
|
|
||||||
fn prod_with_glwe<MUT, REF>(
|
|
||||||
&self,
|
|
||||||
module: &Module<FFT64>,
|
|
||||||
res: &mut GLWECiphertext<MUT>,
|
|
||||||
a: &GLWECiphertext<REF>,
|
|
||||||
scratch: &mut Scratch,
|
|
||||||
) where
|
|
||||||
VecZnx<MUT>: VecZnxToMut,
|
|
||||||
VecZnx<REF>: VecZnxToRef;
|
|
||||||
|
|
||||||
fn prod_with_glwe_inplace<MUT>(&self, module: &Module<FFT64>, res: &mut GLWECiphertext<MUT>, scratch: &mut Scratch)
|
|
||||||
where
|
|
||||||
VecZnx<MUT>: VecZnxToMut + VecZnxToRef,
|
|
||||||
{
|
|
||||||
unsafe {
|
|
||||||
let res_ptr: *mut GLWECiphertext<MUT> = res as *mut GLWECiphertext<MUT>; // This is ok because [Self::mul_rlwe] only updates res at the end.
|
|
||||||
self.prod_with_glwe(&module, &mut *res_ptr, &*res_ptr, scratch);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
fn prod_with_glwe_fourier<MUT, REF>(
|
|
||||||
&self,
|
|
||||||
module: &Module<FFT64>,
|
|
||||||
res: &mut GLWECiphertextFourier<MUT, FFT64>,
|
|
||||||
a: &GLWECiphertextFourier<REF, FFT64>,
|
|
||||||
scratch: &mut Scratch,
|
|
||||||
) where
|
|
||||||
VecZnxDft<MUT, FFT64>: VecZnxDftToMut<FFT64> + VecZnxDftToRef<FFT64> + ZnxInfos,
|
|
||||||
VecZnxDft<REF, FFT64>: VecZnxDftToRef<FFT64> + ZnxInfos,
|
|
||||||
{
|
|
||||||
let log_base2k: usize = self.basek();
|
|
||||||
|
|
||||||
#[cfg(debug_assertions)]
|
|
||||||
{
|
|
||||||
assert_eq!(res.basek(), log_base2k);
|
|
||||||
assert_eq!(self.n(), module.n());
|
|
||||||
assert_eq!(res.n(), module.n());
|
|
||||||
}
|
|
||||||
|
|
||||||
let (a_data, scratch_1) = scratch.tmp_vec_znx(module, a.rank() + 1, a.size());
|
|
||||||
|
|
||||||
let mut a_idft: GLWECiphertext<&mut [u8]> = GLWECiphertext::<&mut [u8]> {
|
|
||||||
data: a_data,
|
|
||||||
basek: a.basek(),
|
|
||||||
k: a.k(),
|
|
||||||
};
|
|
||||||
|
|
||||||
a.idft(module, &mut a_idft, scratch_1);
|
|
||||||
|
|
||||||
let (res_data, scratch_2) = scratch_1.tmp_vec_znx(module, res.rank() + 1, res.size());
|
|
||||||
|
|
||||||
let mut res_idft: GLWECiphertext<&mut [u8]> = GLWECiphertext::<&mut [u8]> {
|
|
||||||
data: res_data,
|
|
||||||
basek: res.basek(),
|
|
||||||
k: res.k(),
|
|
||||||
};
|
|
||||||
|
|
||||||
self.prod_with_glwe(module, &mut res_idft, &a_idft, scratch_2);
|
|
||||||
|
|
||||||
res_idft.dft(module, res);
|
|
||||||
}
|
|
||||||
|
|
||||||
fn prod_with_glwe_fourier_inplace<MUT>(
|
|
||||||
&self,
|
|
||||||
module: &Module<FFT64>,
|
|
||||||
res: &mut GLWECiphertextFourier<MUT, FFT64>,
|
|
||||||
scratch: &mut Scratch,
|
|
||||||
) where
|
|
||||||
VecZnxDft<MUT, FFT64>: VecZnxDftToRef<FFT64> + VecZnxDftToMut<FFT64>,
|
|
||||||
{
|
|
||||||
let log_base2k: usize = self.basek();
|
|
||||||
|
|
||||||
#[cfg(debug_assertions)]
|
|
||||||
{
|
|
||||||
assert_eq!(res.basek(), log_base2k);
|
|
||||||
assert_eq!(self.n(), module.n());
|
|
||||||
assert_eq!(res.n(), module.n());
|
|
||||||
}
|
|
||||||
|
|
||||||
let (res_data, scratch_1) = scratch.tmp_vec_znx(module, res.rank() + 1, res.size());
|
|
||||||
|
|
||||||
let mut res_idft: GLWECiphertext<&mut [u8]> = GLWECiphertext::<&mut [u8]> {
|
|
||||||
data: res_data,
|
|
||||||
basek: res.basek(),
|
|
||||||
k: res.k(),
|
|
||||||
};
|
|
||||||
|
|
||||||
res.idft(module, &mut res_idft, scratch_1);
|
|
||||||
|
|
||||||
self.prod_with_glwe_inplace(module, &mut res_idft, scratch_1);
|
|
||||||
|
|
||||||
res_idft.dft(module, res);
|
|
||||||
}
|
|
||||||
|
|
||||||
fn prod_with_vec_glwe<RES, LHS>(&self, module: &Module<FFT64>, res: &mut RES, a: &LHS, scratch: &mut Scratch)
|
|
||||||
where
|
|
||||||
LHS: GetRow<FFT64> + Infos,
|
|
||||||
RES: SetRow<FFT64> + Infos,
|
|
||||||
{
|
|
||||||
let (tmp_row_data, scratch1) = scratch.tmp_vec_znx_dft(module, a.cols(), a.size());
|
|
||||||
|
|
||||||
let mut tmp_a_row: GLWECiphertextFourier<&mut [u8], FFT64> = GLWECiphertextFourier::<&mut [u8], FFT64> {
|
|
||||||
data: tmp_row_data,
|
|
||||||
basek: a.basek(),
|
|
||||||
k: a.k(),
|
|
||||||
};
|
|
||||||
|
|
||||||
let (tmp_res_data, scratch2) = scratch1.tmp_vec_znx_dft(module, res.cols(), res.size());
|
|
||||||
|
|
||||||
let mut tmp_res_row: GLWECiphertextFourier<&mut [u8], FFT64> = GLWECiphertextFourier::<&mut [u8], FFT64> {
|
|
||||||
data: tmp_res_data,
|
|
||||||
basek: res.basek(),
|
|
||||||
k: res.k(),
|
|
||||||
};
|
|
||||||
|
|
||||||
let min_rows: usize = res.rows().min(a.rows());
|
|
||||||
|
|
||||||
(0..res.rows()).for_each(|row_i| {
|
|
||||||
(0..res.cols()).for_each(|col_j| {
|
|
||||||
a.get_row(module, row_i, col_j, &mut tmp_a_row);
|
|
||||||
self.prod_with_glwe_fourier(module, &mut tmp_res_row, &tmp_a_row, scratch2);
|
|
||||||
res.set_row(module, row_i, col_j, &tmp_res_row);
|
|
||||||
});
|
|
||||||
});
|
|
||||||
|
|
||||||
tmp_res_row.data.zero();
|
|
||||||
|
|
||||||
(min_rows..res.rows()).for_each(|row_i| {
|
|
||||||
(0..self.cols()).for_each(|col_j| {
|
|
||||||
res.set_row(module, row_i, col_j, &tmp_res_row);
|
|
||||||
});
|
|
||||||
});
|
|
||||||
}
|
|
||||||
|
|
||||||
fn prod_with_vec_glwe_inplace<RES>(&self, module: &Module<FFT64>, res: &mut RES, scratch: &mut Scratch)
|
|
||||||
where
|
|
||||||
RES: GetRow<FFT64> + SetRow<FFT64> + Infos,
|
|
||||||
{
|
|
||||||
let (tmp_row_data, scratch1) = scratch.tmp_vec_znx_dft(module, res.cols(), res.size());
|
|
||||||
|
|
||||||
let mut tmp_row: GLWECiphertextFourier<&mut [u8], FFT64> = GLWECiphertextFourier::<&mut [u8], FFT64> {
|
|
||||||
data: tmp_row_data,
|
|
||||||
basek: res.basek(),
|
|
||||||
k: res.k(),
|
|
||||||
};
|
|
||||||
|
|
||||||
(0..res.rows()).for_each(|row_i| {
|
|
||||||
(0..res.cols()).for_each(|col_j| {
|
|
||||||
res.get_row(module, row_i, col_j, &mut tmp_row);
|
|
||||||
self.prod_with_glwe_fourier_inplace(module, &mut tmp_row, scratch1);
|
|
||||||
res.set_row(module, row_i, col_j, &tmp_row);
|
|
||||||
});
|
|
||||||
});
|
|
||||||
}
|
|
||||||
}
|
|
||||||
Reference in New Issue
Block a user