mirror of
https://github.com/arnaucube/poulpy.git
synced 2026-02-10 05:06:44 +01:00
refactor of key-switching & external product
This commit is contained in:
@@ -1,35 +1,31 @@
|
||||
use base2k::{
|
||||
Backend, FFT64, MatZnxDft, MatZnxDftAlloc, MatZnxDftOps, MatZnxDftScratch, MatZnxDftToMut, MatZnxDftToRef, Module, ScalarZnx,
|
||||
ScalarZnxDft, ScalarZnxDftToRef, ScalarZnxToRef, Scratch, VecZnx, VecZnxAlloc, VecZnxBig, VecZnxBigOps, VecZnxBigScratch,
|
||||
VecZnxDft, VecZnxDftAlloc, VecZnxDftOps, VecZnxDftToMut, VecZnxDftToRef, VecZnxOps, VecZnxToMut, VecZnxToRef, ZnxInfos,
|
||||
ZnxZero,
|
||||
Backend, FFT64, MatZnxDft, MatZnxDftAlloc, MatZnxDftOps, MatZnxDftToMut, MatZnxDftToRef, Module, ScalarZnx, ScalarZnxDft,
|
||||
ScalarZnxDftToRef, ScalarZnxToRef, Scratch, VecZnxAlloc, VecZnxDft, VecZnxDftAlloc, VecZnxDftOps, VecZnxDftToMut,
|
||||
VecZnxDftToRef, VecZnxOps, ZnxInfos, ZnxZero,
|
||||
};
|
||||
use sampling::source::Source;
|
||||
|
||||
use crate::{
|
||||
elem::{GetRow, Infos, SetRow},
|
||||
gglwe_ciphertext::GGLWECiphertext,
|
||||
glwe_ciphertext::GLWECiphertext,
|
||||
glwe_ciphertext_fourier::GLWECiphertextFourier,
|
||||
glwe_plaintext::GLWEPlaintext,
|
||||
keys::SecretKeyFourier,
|
||||
keyswitch_key::GLWESwitchingKey,
|
||||
utils::derive_size,
|
||||
vec_glwe_product::{VecGLWEProduct, VecGLWEProductScratchSpace},
|
||||
};
|
||||
|
||||
pub struct GGSWCiphertext<C, B: Backend> {
|
||||
pub data: MatZnxDft<C, B>,
|
||||
pub log_base2k: usize,
|
||||
pub log_k: usize,
|
||||
pub basek: usize,
|
||||
pub k: usize,
|
||||
}
|
||||
|
||||
impl<B: Backend> GGSWCiphertext<Vec<u8>, B> {
|
||||
pub fn new(module: &Module<B>, log_base2k: usize, log_k: usize, rows: usize, rank: usize) -> Self {
|
||||
pub fn new(module: &Module<B>, basek: usize, k: usize, rows: usize, rank: usize) -> Self {
|
||||
Self {
|
||||
data: module.new_mat_znx_dft(rows, rank + 1, rank + 1, derive_size(log_base2k, log_k)),
|
||||
log_base2k: log_base2k,
|
||||
log_k: log_k,
|
||||
data: module.new_mat_znx_dft(rows, rank + 1, rank + 1, derive_size(basek, k)),
|
||||
basek: basek,
|
||||
k: k,
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -42,11 +38,11 @@ impl<T, B: Backend> Infos for GGSWCiphertext<T, B> {
|
||||
}
|
||||
|
||||
fn basek(&self) -> usize {
|
||||
self.log_base2k
|
||||
self.basek
|
||||
}
|
||||
|
||||
fn k(&self) -> usize {
|
||||
self.log_k
|
||||
self.k
|
||||
}
|
||||
}
|
||||
|
||||
@@ -82,35 +78,28 @@ impl GGSWCiphertext<Vec<u8>, FFT64> {
|
||||
+ module.bytes_of_vec_znx_dft(rank + 1, size)
|
||||
}
|
||||
|
||||
pub fn keyswitch_scratch_space(
|
||||
pub fn external_product_scratch_space(
|
||||
module: &Module<FFT64>,
|
||||
res_size: usize,
|
||||
lhs: usize,
|
||||
rhs: usize,
|
||||
rank_in: usize,
|
||||
rank_out: usize,
|
||||
out_size: usize,
|
||||
in_size: usize,
|
||||
ggsw_size: usize,
|
||||
rank: usize,
|
||||
) -> usize {
|
||||
<GGLWECiphertext<Vec<u8>, FFT64> as VecGLWEProductScratchSpace>::prod_with_vec_glwe_scratch_space(
|
||||
module, res_size, lhs, rhs, rank_in, rank_out,
|
||||
)
|
||||
let tmp_in: usize = module.bytes_of_vec_znx_dft(rank + 1, in_size);
|
||||
let tmp_out: usize = module.bytes_of_vec_znx_dft(rank + 1, out_size);
|
||||
let ggsw: usize = GLWECiphertextFourier::external_product_scratch_space(module, out_size, in_size, ggsw_size, rank);
|
||||
tmp_in + tmp_out + ggsw
|
||||
}
|
||||
|
||||
pub fn keyswitch_inplace_scratch_space(module: &Module<FFT64>, res_size: usize, rhs: usize, rank: usize) -> usize {
|
||||
<GGLWECiphertext<Vec<u8>, FFT64> as VecGLWEProductScratchSpace>::prod_with_vec_glwe_inplace_scratch_space(
|
||||
module, res_size, rhs, rank,
|
||||
)
|
||||
}
|
||||
|
||||
pub fn external_product_scratch_space(module: &Module<FFT64>, res_size: usize, lhs: usize, rhs: usize, rank: usize) -> usize {
|
||||
<GGSWCiphertext<Vec<u8>, FFT64> as VecGLWEProductScratchSpace>::prod_with_vec_glwe_scratch_space(
|
||||
module, res_size, lhs, rhs, rank, rank,
|
||||
)
|
||||
}
|
||||
|
||||
pub fn external_product_inplace_scratch_space(module: &Module<FFT64>, res_size: usize, rhs: usize, rank: usize) -> usize {
|
||||
<GGSWCiphertext<Vec<u8>, FFT64> as VecGLWEProductScratchSpace>::prod_with_glwe_inplace_scratch_space(
|
||||
module, res_size, rhs, rank,
|
||||
)
|
||||
pub fn external_product_inplace_scratch_space(
|
||||
module: &Module<FFT64>,
|
||||
out_size: usize,
|
||||
ggsw_size: usize,
|
||||
rank: usize,
|
||||
) -> usize {
|
||||
let tmp: usize = module.bytes_of_vec_znx_dft(rank + 1, out_size);
|
||||
let ggsw: usize = GLWECiphertextFourier::external_product_inplace_scratch_space(module, out_size, ggsw_size, rank);
|
||||
tmp + ggsw
|
||||
}
|
||||
}
|
||||
|
||||
@@ -140,7 +129,7 @@ where
|
||||
}
|
||||
|
||||
let size: usize = self.size();
|
||||
let log_base2k: usize = self.basek();
|
||||
let basek: usize = self.basek();
|
||||
let k: usize = self.k();
|
||||
let cols: usize = self.rank() + 1;
|
||||
|
||||
@@ -149,20 +138,20 @@ where
|
||||
|
||||
let mut vec_znx_pt: GLWEPlaintext<&mut [u8]> = GLWEPlaintext {
|
||||
data: tmp_znx_pt,
|
||||
basek: log_base2k,
|
||||
basek: basek,
|
||||
k: k,
|
||||
};
|
||||
|
||||
let mut vec_znx_ct: GLWECiphertext<&mut [u8]> = GLWECiphertext {
|
||||
data: tmp_znx_ct,
|
||||
basek: log_base2k,
|
||||
basek: basek,
|
||||
k,
|
||||
};
|
||||
|
||||
(0..self.rows()).for_each(|row_j| {
|
||||
// Adds the scalar_znx_pt to the i-th limb of the vec_znx_pt
|
||||
module.vec_znx_add_scalar_inplace(&mut vec_znx_pt, 0, row_j, pt, 0);
|
||||
module.vec_znx_normalize_inplace(log_base2k, &mut vec_znx_pt, 0, scrach_2);
|
||||
module.vec_znx_normalize_inplace(basek, &mut vec_znx_pt, 0, scrach_2);
|
||||
|
||||
(0..cols).for_each(|col_i| {
|
||||
// rlwe encrypt of vec_znx_pt into vec_znx_ct
|
||||
@@ -193,30 +182,6 @@ where
|
||||
});
|
||||
}
|
||||
|
||||
pub fn keyswitch<DataLhs, DataRhs>(
|
||||
&mut self,
|
||||
module: &Module<FFT64>,
|
||||
lhs: &GGSWCiphertext<DataLhs, FFT64>,
|
||||
rhs: &GLWESwitchingKey<DataRhs, FFT64>,
|
||||
scratch: &mut Scratch,
|
||||
) where
|
||||
MatZnxDft<DataLhs, FFT64>: MatZnxDftToRef<FFT64>,
|
||||
MatZnxDft<DataRhs, FFT64>: MatZnxDftToRef<FFT64>,
|
||||
{
|
||||
rhs.0.prod_with_vec_glwe(module, self, lhs, scratch);
|
||||
}
|
||||
|
||||
pub fn keyswitch_inplace<DataRhs>(
|
||||
&mut self,
|
||||
module: &Module<FFT64>,
|
||||
rhs: &GLWESwitchingKey<DataRhs, FFT64>,
|
||||
scratch: &mut Scratch,
|
||||
) where
|
||||
MatZnxDft<DataRhs, FFT64>: MatZnxDftToRef<FFT64>,
|
||||
{
|
||||
rhs.0.prod_with_vec_glwe_inplace(module, self, scratch);
|
||||
}
|
||||
|
||||
pub fn external_product<DataLhs, DataRhs>(
|
||||
&mut self,
|
||||
module: &Module<FFT64>,
|
||||
@@ -227,7 +192,55 @@ where
|
||||
MatZnxDft<DataLhs, FFT64>: MatZnxDftToRef<FFT64>,
|
||||
MatZnxDft<DataRhs, FFT64>: MatZnxDftToRef<FFT64>,
|
||||
{
|
||||
rhs.prod_with_vec_glwe(module, self, lhs, scratch);
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
assert_eq!(
|
||||
self.rank(),
|
||||
lhs.rank(),
|
||||
"ggsw_out rank: {} != ggsw_in rank: {}",
|
||||
self.rank(),
|
||||
lhs.rank()
|
||||
);
|
||||
assert_eq!(
|
||||
self.rank(),
|
||||
rhs.rank(),
|
||||
"ggsw_in rank: {} != ggsw_apply rank: {}",
|
||||
self.rank(),
|
||||
rhs.rank()
|
||||
);
|
||||
}
|
||||
|
||||
let (tmp_in_data, scratch1) = scratch.tmp_vec_znx_dft(module, lhs.rank() + 1, lhs.size());
|
||||
|
||||
let mut tmp_in: GLWECiphertextFourier<&mut [u8], FFT64> = GLWECiphertextFourier::<&mut [u8], FFT64> {
|
||||
data: tmp_in_data,
|
||||
basek: lhs.basek(),
|
||||
k: lhs.k(),
|
||||
};
|
||||
|
||||
let (tmp_out_data, scratch2) = scratch1.tmp_vec_znx_dft(module, self.rank() + 1, self.size());
|
||||
|
||||
let mut tmp_out: GLWECiphertextFourier<&mut [u8], FFT64> = GLWECiphertextFourier::<&mut [u8], FFT64> {
|
||||
data: tmp_out_data,
|
||||
basek: self.basek(),
|
||||
k: self.k(),
|
||||
};
|
||||
|
||||
(0..self.rank() + 1).for_each(|col_i| {
|
||||
(0..self.rows()).for_each(|row_j| {
|
||||
lhs.get_row(module, row_j, col_i, &mut tmp_in);
|
||||
tmp_out.external_product(module, &tmp_in, rhs, scratch2);
|
||||
self.set_row(module, row_j, col_i, &tmp_out);
|
||||
});
|
||||
});
|
||||
|
||||
tmp_out.data.zero();
|
||||
|
||||
(self.rows().min(lhs.rows())..self.rows()).for_each(|row_i| {
|
||||
(0..self.rank() + 1).for_each(|col_j| {
|
||||
self.set_row(module, row_i, col_j, &tmp_out);
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
pub fn external_product_inplace<DataRhs>(
|
||||
@@ -238,7 +251,32 @@ where
|
||||
) where
|
||||
MatZnxDft<DataRhs, FFT64>: MatZnxDftToRef<FFT64>,
|
||||
{
|
||||
rhs.prod_with_vec_glwe_inplace(module, self, scratch);
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
assert_eq!(
|
||||
self.rank(),
|
||||
rhs.rank(),
|
||||
"ggsw_out rank: {} != ggsw_apply: {}",
|
||||
self.rank(),
|
||||
rhs.rank()
|
||||
);
|
||||
}
|
||||
|
||||
let (tmp_data, scratch1) = scratch.tmp_vec_znx_dft(module, self.rank() + 1, self.size());
|
||||
|
||||
let mut tmp: GLWECiphertextFourier<&mut [u8], FFT64> = GLWECiphertextFourier::<&mut [u8], FFT64> {
|
||||
data: tmp_data,
|
||||
basek: self.basek(),
|
||||
k: self.k(),
|
||||
};
|
||||
|
||||
(0..self.rank() + 1).for_each(|col_i| {
|
||||
(0..self.rows()).for_each(|row_j| {
|
||||
self.get_row(module, row_j, col_i, &mut tmp);
|
||||
tmp.external_product_inplace(module, rhs, scratch1);
|
||||
self.set_row(module, row_j, col_i, &tmp);
|
||||
});
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
@@ -270,73 +308,3 @@ where
|
||||
module.vmp_prepare_row(self, row_i, col_j, a);
|
||||
}
|
||||
}
|
||||
|
||||
impl VecGLWEProductScratchSpace for GGSWCiphertext<Vec<u8>, FFT64> {
|
||||
fn prod_with_glwe_scratch_space(
|
||||
module: &Module<FFT64>,
|
||||
res_size: usize,
|
||||
a_size: usize,
|
||||
rgsw_size: usize,
|
||||
rank_in: usize,
|
||||
rank_out: usize,
|
||||
) -> usize {
|
||||
module.bytes_of_vec_znx_dft(rank_out + 1, rgsw_size)
|
||||
+ ((module.bytes_of_vec_znx_dft(rank_in + 1, a_size)
|
||||
+ module.vmp_apply_tmp_bytes(
|
||||
res_size,
|
||||
a_size,
|
||||
a_size,
|
||||
rank_in + 1,
|
||||
rank_out + 1,
|
||||
rgsw_size,
|
||||
))
|
||||
| module.vec_znx_big_normalize_tmp_bytes())
|
||||
}
|
||||
}
|
||||
|
||||
impl<C> VecGLWEProduct for GGSWCiphertext<C, FFT64>
|
||||
where
|
||||
MatZnxDft<C, FFT64>: MatZnxDftToRef<FFT64> + ZnxInfos,
|
||||
{
|
||||
fn prod_with_glwe<R, A>(
|
||||
&self,
|
||||
module: &Module<FFT64>,
|
||||
res: &mut GLWECiphertext<R>,
|
||||
a: &GLWECiphertext<A>,
|
||||
scratch: &mut Scratch,
|
||||
) where
|
||||
VecZnx<R>: VecZnxToMut,
|
||||
VecZnx<A>: VecZnxToRef,
|
||||
{
|
||||
let log_base2k: usize = self.basek();
|
||||
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
assert_eq!(self.rank(), a.rank());
|
||||
assert_eq!(self.rank(), res.rank());
|
||||
assert_eq!(res.basek(), log_base2k);
|
||||
assert_eq!(a.basek(), log_base2k);
|
||||
assert_eq!(self.n(), module.n());
|
||||
assert_eq!(res.n(), module.n());
|
||||
assert_eq!(a.n(), module.n());
|
||||
}
|
||||
|
||||
let cols: usize = self.rank() + 1;
|
||||
|
||||
let (mut res_dft, scratch1) = scratch.tmp_vec_znx_dft(module, cols, self.size()); // Todo optimise
|
||||
|
||||
{
|
||||
let (mut a_dft, scratch2) = scratch1.tmp_vec_znx_dft(module, cols, a.size());
|
||||
(0..cols).for_each(|col_i| {
|
||||
module.vec_znx_dft(&mut a_dft, col_i, a, col_i);
|
||||
});
|
||||
module.vmp_apply(&mut res_dft, &a_dft, self, scratch2);
|
||||
}
|
||||
|
||||
let res_big: VecZnxBig<&mut [u8], FFT64> = module.vec_znx_idft_consume(res_dft);
|
||||
|
||||
(0..cols).for_each(|i| {
|
||||
module.vec_znx_big_normalize(log_base2k, res, i, &res_big, i, scratch1);
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user