mirror of
https://github.com/arnaucube/poulpy.git
synced 2026-02-10 13:16:44 +01:00
refactor of key-switching & external product
This commit is contained in:
@@ -1,22 +1,20 @@
|
||||
use base2k::{
|
||||
AddNormal, Backend, FFT64, FillUniform, MatZnxDft, MatZnxDftToRef, Module, ScalarZnxAlloc, ScalarZnxDft, ScalarZnxDftAlloc,
|
||||
ScalarZnxDftOps, ScalarZnxDftToRef, Scratch, VecZnx, VecZnxAlloc, VecZnxBig, VecZnxBigAlloc, VecZnxBigOps, VecZnxBigScratch,
|
||||
VecZnxDft, VecZnxDftAlloc, VecZnxDftOps, VecZnxDftToMut, VecZnxDftToRef, VecZnxOps, VecZnxToMut, VecZnxToRef, ZnxInfos,
|
||||
ZnxZero,
|
||||
AddNormal, Backend, FFT64, FillUniform, MatZnxDft, MatZnxDftOps, MatZnxDftScratch, MatZnxDftToRef, Module, ScalarZnxAlloc,
|
||||
ScalarZnxDft, ScalarZnxDftAlloc, ScalarZnxDftOps, ScalarZnxDftToRef, Scratch, VecZnx, VecZnxAlloc, VecZnxBig, VecZnxBigAlloc,
|
||||
VecZnxBigOps, VecZnxBigScratch, VecZnxDft, VecZnxDftAlloc, VecZnxDftOps, VecZnxDftToMut, VecZnxDftToRef, VecZnxOps,
|
||||
VecZnxToMut, VecZnxToRef, ZnxInfos, ZnxZero,
|
||||
};
|
||||
use sampling::source::Source;
|
||||
|
||||
use crate::{
|
||||
SIX_SIGMA,
|
||||
elem::Infos,
|
||||
gglwe_ciphertext::GGLWECiphertext,
|
||||
ggsw_ciphertext::GGSWCiphertext,
|
||||
glwe_ciphertext_fourier::GLWECiphertextFourier,
|
||||
glwe_plaintext::GLWEPlaintext,
|
||||
keys::{GLWEPublicKey, SecretDistribution, SecretKeyFourier},
|
||||
keyswitch_key::GLWESwitchingKey,
|
||||
utils::derive_size,
|
||||
vec_glwe_product::{VecGLWEProduct, VecGLWEProductScratchSpace},
|
||||
};
|
||||
|
||||
pub struct GLWECiphertext<C> {
|
||||
@@ -115,33 +113,50 @@ impl GLWECiphertext<Vec<u8>> {
|
||||
|
||||
pub fn keyswitch_scratch_space(
|
||||
module: &Module<FFT64>,
|
||||
res_size: usize,
|
||||
lhs: usize,
|
||||
rhs: usize,
|
||||
rank_in: usize,
|
||||
rank_out: usize,
|
||||
out_size: usize,
|
||||
out_rank: usize,
|
||||
in_size: usize,
|
||||
in_rank: usize,
|
||||
ksk_size: usize,
|
||||
) -> usize {
|
||||
<GGLWECiphertext<Vec<u8>, FFT64> as VecGLWEProductScratchSpace>::prod_with_glwe_scratch_space(
|
||||
module, res_size, lhs, rhs, rank_in, rank_out,
|
||||
)
|
||||
module.bytes_of_vec_znx_dft(out_rank + 1, ksk_size)
|
||||
+ (module.vec_znx_big_normalize_tmp_bytes()
|
||||
| (module.vmp_apply_tmp_bytes(
|
||||
out_size,
|
||||
in_size,
|
||||
in_size,
|
||||
in_rank + 1,
|
||||
out_rank + 1,
|
||||
ksk_size,
|
||||
) + module.bytes_of_vec_znx_dft(in_size, in_size)))
|
||||
}
|
||||
|
||||
pub fn keyswitch_inplace_scratch_space(module: &Module<FFT64>, res_size: usize, rhs: usize, rank: usize) -> usize {
|
||||
<GGLWECiphertext<Vec<u8>, FFT64> as VecGLWEProductScratchSpace>::prod_with_glwe_inplace_scratch_space(
|
||||
module, res_size, rhs, rank,
|
||||
)
|
||||
pub fn keyswitch_inplace_scratch_space(module: &Module<FFT64>, out_size: usize, out_rank: usize, ksk_size: usize) -> usize {
|
||||
GLWECiphertext::keyswitch_scratch_space(module, out_size, out_rank, out_size, out_rank, ksk_size)
|
||||
}
|
||||
|
||||
pub fn external_product_scratch_space(module: &Module<FFT64>, res_size: usize, lhs: usize, rhs: usize, rank: usize) -> usize {
|
||||
<GGSWCiphertext<Vec<u8>, FFT64> as VecGLWEProductScratchSpace>::prod_with_glwe_scratch_space(
|
||||
module, res_size, lhs, rhs, rank, rank,
|
||||
)
|
||||
pub fn external_product_scratch_space(
|
||||
module: &Module<FFT64>,
|
||||
out_size: usize,
|
||||
in_size: usize,
|
||||
ggsw_size: usize,
|
||||
rank: usize,
|
||||
) -> usize {
|
||||
module.bytes_of_vec_znx_dft(rank + 1, ggsw_size)
|
||||
+ ((module.bytes_of_vec_znx_dft(rank + 1, in_size)
|
||||
+ module.vmp_apply_tmp_bytes(
|
||||
out_size,
|
||||
in_size,
|
||||
in_size, // rows
|
||||
rank + 1, // cols in
|
||||
rank + 1, // cols out
|
||||
ggsw_size,
|
||||
))
|
||||
| module.vec_znx_big_normalize_tmp_bytes())
|
||||
}
|
||||
|
||||
pub fn external_product_inplace_scratch_space(module: &Module<FFT64>, res_size: usize, rhs: usize, rank: usize) -> usize {
|
||||
<GGSWCiphertext<Vec<u8>, FFT64> as VecGLWEProductScratchSpace>::prod_with_glwe_inplace_scratch_space(
|
||||
module, res_size, rhs, rank,
|
||||
)
|
||||
GLWECiphertext::external_product_scratch_space(module, res_size, res_size, rhs, rank)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -235,7 +250,50 @@ where
|
||||
VecZnx<DataLhs>: VecZnxToRef,
|
||||
MatZnxDft<DataRhs, FFT64>: MatZnxDftToRef<FFT64>,
|
||||
{
|
||||
rhs.0.prod_with_glwe(module, self, lhs, scratch);
|
||||
let basek: usize = self.basek();
|
||||
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
assert_eq!(lhs.rank(), rhs.rank_in());
|
||||
assert_eq!(self.rank(), rhs.rank_out());
|
||||
assert_eq!(self.basek(), basek);
|
||||
assert_eq!(lhs.basek(), basek);
|
||||
assert_eq!(rhs.n(), module.n());
|
||||
assert_eq!(self.n(), module.n());
|
||||
assert_eq!(lhs.n(), module.n());
|
||||
assert!(
|
||||
scratch.available()
|
||||
>= GLWECiphertext::keyswitch_scratch_space(
|
||||
module,
|
||||
self.size(),
|
||||
self.rank(),
|
||||
lhs.size(),
|
||||
lhs.rank(),
|
||||
rhs.size(),
|
||||
)
|
||||
);
|
||||
}
|
||||
|
||||
let cols_in: usize = rhs.rank_in();
|
||||
let cols_out: usize = rhs.rank_out() + 1;
|
||||
|
||||
let (mut res_dft, scratch1) = scratch.tmp_vec_znx_dft(module, cols_out, rhs.size()); // Todo optimise
|
||||
|
||||
{
|
||||
let (mut ai_dft, scratch2) = scratch1.tmp_vec_znx_dft(module, cols_in, lhs.size());
|
||||
(0..cols_in).for_each(|col_i| {
|
||||
module.vec_znx_dft(&mut ai_dft, col_i, lhs, col_i + 1);
|
||||
});
|
||||
module.vmp_apply(&mut res_dft, &ai_dft, rhs, scratch2);
|
||||
}
|
||||
|
||||
let mut res_big: VecZnxBig<&mut [u8], FFT64> = module.vec_znx_idft_consume(res_dft);
|
||||
|
||||
module.vec_znx_big_add_small_inplace(&mut res_big, 0, lhs, 0);
|
||||
|
||||
(0..cols_out).for_each(|i| {
|
||||
module.vec_znx_big_normalize(basek, self, i, &res_big, i, scratch1);
|
||||
});
|
||||
}
|
||||
|
||||
pub fn keyswitch_inplace<DataRhs>(
|
||||
@@ -246,7 +304,10 @@ where
|
||||
) where
|
||||
MatZnxDft<DataRhs, FFT64>: MatZnxDftToRef<FFT64>,
|
||||
{
|
||||
rhs.0.prod_with_glwe_inplace(module, self, scratch);
|
||||
unsafe {
|
||||
let self_ptr: *mut GLWECiphertext<DataSelf> = self as *mut GLWECiphertext<DataSelf>;
|
||||
self.keyswitch(&module, &*self_ptr, rhs, scratch);
|
||||
}
|
||||
}
|
||||
|
||||
pub fn external_product<DataLhs, DataRhs>(
|
||||
@@ -259,7 +320,36 @@ where
|
||||
VecZnx<DataLhs>: VecZnxToRef,
|
||||
MatZnxDft<DataRhs, FFT64>: MatZnxDftToRef<FFT64>,
|
||||
{
|
||||
rhs.prod_with_glwe(module, self, lhs, scratch);
|
||||
let basek: usize = self.basek();
|
||||
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
assert_eq!(rhs.rank(), lhs.rank());
|
||||
assert_eq!(rhs.rank(), self.rank());
|
||||
assert_eq!(self.basek(), basek);
|
||||
assert_eq!(lhs.basek(), basek);
|
||||
assert_eq!(rhs.n(), module.n());
|
||||
assert_eq!(self.n(), module.n());
|
||||
assert_eq!(lhs.n(), module.n());
|
||||
}
|
||||
|
||||
let cols: usize = rhs.rank() + 1;
|
||||
|
||||
let (mut res_dft, scratch1) = scratch.tmp_vec_znx_dft(module, cols, rhs.size()); // Todo optimise
|
||||
|
||||
{
|
||||
let (mut a_dft, scratch2) = scratch1.tmp_vec_znx_dft(module, cols, lhs.size());
|
||||
(0..cols).for_each(|col_i| {
|
||||
module.vec_znx_dft(&mut a_dft, col_i, lhs, col_i);
|
||||
});
|
||||
module.vmp_apply(&mut res_dft, &a_dft, rhs, scratch2);
|
||||
}
|
||||
|
||||
let res_big: VecZnxBig<&mut [u8], FFT64> = module.vec_znx_idft_consume(res_dft);
|
||||
|
||||
(0..cols).for_each(|i| {
|
||||
module.vec_znx_big_normalize(basek, self, i, &res_big, i, scratch1);
|
||||
});
|
||||
}
|
||||
|
||||
pub fn external_product_inplace<DataRhs>(
|
||||
@@ -270,7 +360,10 @@ where
|
||||
) where
|
||||
MatZnxDft<DataRhs, FFT64>: MatZnxDftToRef<FFT64>,
|
||||
{
|
||||
rhs.prod_with_glwe_inplace(module, self, scratch);
|
||||
unsafe {
|
||||
let self_ptr: *mut GLWECiphertext<DataSelf> = self as *mut GLWECiphertext<DataSelf>;
|
||||
self.external_product(&module, &*self_ptr, rhs, scratch);
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn encrypt_sk_private<DataPt, DataSk>(
|
||||
|
||||
Reference in New Issue
Block a user