mirror of
https://github.com/arnaucube/poulpy.git
synced 2026-02-10 13:16:44 +01:00
refactor of key-switching & external product
This commit is contained in:
@@ -1,20 +1,13 @@
|
||||
use base2k::{
|
||||
Backend, FFT64, MatZnxDft, MatZnxDftToRef, Module, ScalarZnxDft, ScalarZnxDftOps, ScalarZnxDftToRef, Scratch, VecZnx,
|
||||
VecZnxAlloc, VecZnxBig, VecZnxBigAlloc, VecZnxBigOps, VecZnxBigScratch, VecZnxDft, VecZnxDftAlloc, VecZnxDftOps,
|
||||
VecZnxDftToMut, VecZnxDftToRef, VecZnxToMut, VecZnxToRef, ZnxZero,
|
||||
Backend, FFT64, MatZnxDft, MatZnxDftOps, MatZnxDftScratch, MatZnxDftToRef, Module, ScalarZnxDft, ScalarZnxDftOps,
|
||||
ScalarZnxDftToRef, Scratch, VecZnx, VecZnxAlloc, VecZnxBig, VecZnxBigAlloc, VecZnxBigOps, VecZnxBigScratch, VecZnxDft,
|
||||
VecZnxDftAlloc, VecZnxDftOps, VecZnxDftToMut, VecZnxDftToRef, VecZnxToMut, VecZnxToRef, ZnxZero,
|
||||
};
|
||||
use sampling::source::Source;
|
||||
|
||||
use crate::{
|
||||
elem::Infos,
|
||||
gglwe_ciphertext::GGLWECiphertext,
|
||||
ggsw_ciphertext::GGSWCiphertext,
|
||||
glwe_ciphertext::GLWECiphertext,
|
||||
glwe_plaintext::GLWEPlaintext,
|
||||
keys::SecretKeyFourier,
|
||||
keyswitch_key::GLWESwitchingKey,
|
||||
utils::derive_size,
|
||||
vec_glwe_product::{VecGLWEProduct, VecGLWEProductScratchSpace},
|
||||
elem::Infos, ggsw_ciphertext::GGSWCiphertext, glwe_ciphertext::GLWECiphertext, glwe_plaintext::GLWEPlaintext,
|
||||
keys::SecretKeyFourier, keyswitch_key::GLWESwitchingKey, utils::derive_size,
|
||||
};
|
||||
|
||||
pub struct GLWECiphertextFourier<C, B: Backend> {
|
||||
@@ -24,11 +17,11 @@ pub struct GLWECiphertextFourier<C, B: Backend> {
|
||||
}
|
||||
|
||||
impl<B: Backend> GLWECiphertextFourier<Vec<u8>, B> {
|
||||
pub fn new(module: &Module<B>, log_base2k: usize, log_k: usize, rank: usize) -> Self {
|
||||
pub fn new(module: &Module<B>, basek: usize, k: usize, rank: usize) -> Self {
|
||||
Self {
|
||||
data: module.new_vec_znx_dft(rank + 1, derive_size(log_base2k, log_k)),
|
||||
basek: log_base2k,
|
||||
k: log_k,
|
||||
data: module.new_vec_znx_dft(rank + 1, derive_size(basek, k)),
|
||||
basek: basek,
|
||||
k: k,
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -92,33 +85,56 @@ impl GLWECiphertextFourier<Vec<u8>, FFT64> {
|
||||
|
||||
pub fn keyswitch_scratch_space(
|
||||
module: &Module<FFT64>,
|
||||
res_size: usize,
|
||||
lhs: usize,
|
||||
rhs: usize,
|
||||
rank_in: usize,
|
||||
rank_out: usize,
|
||||
out_size: usize,
|
||||
out_rank: usize,
|
||||
in_size: usize,
|
||||
in_rank: usize,
|
||||
ksk_size: usize,
|
||||
) -> usize {
|
||||
<GGLWECiphertext<Vec<u8>, FFT64> as VecGLWEProductScratchSpace>::prod_with_glwe_fourier_scratch_space(
|
||||
module, res_size, lhs, rhs, rank_in, rank_out,
|
||||
)
|
||||
let res_dft: usize = module.bytes_of_vec_znx_dft(out_rank + 1, out_size);
|
||||
|
||||
let vmp = module.bytes_of_vec_znx_dft(in_rank, in_size)
|
||||
+ module.vmp_apply_tmp_bytes(
|
||||
out_size,
|
||||
in_size,
|
||||
in_size,
|
||||
in_rank + 1,
|
||||
out_rank + 1,
|
||||
ksk_size,
|
||||
);
|
||||
let res_small: usize = module.bytes_of_vec_znx(out_rank + 1, out_size);
|
||||
let add_a0: usize = module.bytes_of_vec_znx_big(1, in_size) + module.vec_znx_idft_tmp_bytes();
|
||||
let normalize: usize = module.vec_znx_big_normalize_tmp_bytes();
|
||||
|
||||
res_dft + (vmp | add_a0 | (res_small + normalize))
|
||||
}
|
||||
|
||||
pub fn keyswitch_inplace_scratch_space(module: &Module<FFT64>, res_size: usize, rhs: usize, rank: usize) -> usize {
|
||||
<GGLWECiphertext<Vec<u8>, FFT64> as VecGLWEProductScratchSpace>::prod_with_glwe_fourier_inplace_scratch_space(
|
||||
module, res_size, rhs, rank,
|
||||
)
|
||||
pub fn keyswitch_inplace_scratch_space(module: &Module<FFT64>, out_size: usize, out_rank: usize, ksk_size: usize) -> usize {
|
||||
Self::keyswitch_scratch_space(module, out_size, out_rank, out_size, out_rank, ksk_size)
|
||||
}
|
||||
|
||||
pub fn external_product_scratch_space(module: &Module<FFT64>, res_size: usize, lhs: usize, rhs: usize, rank: usize) -> usize {
|
||||
<GGSWCiphertext<Vec<u8>, FFT64> as VecGLWEProductScratchSpace>::prod_with_glwe_fourier_scratch_space(
|
||||
module, res_size, lhs, rhs, rank, rank,
|
||||
)
|
||||
pub fn external_product_scratch_space(
|
||||
module: &Module<FFT64>,
|
||||
out_size: usize,
|
||||
in_size: usize,
|
||||
ggsw_size: usize,
|
||||
rank: usize,
|
||||
) -> usize {
|
||||
let res_dft: usize = module.bytes_of_vec_znx_dft(rank + 1, out_size);
|
||||
let vmp: usize = module.vmp_apply_tmp_bytes(out_size, in_size, in_size, rank + 1, rank + 1, ggsw_size);
|
||||
let res_small: usize = module.bytes_of_vec_znx(rank + 1, out_size);
|
||||
let normalize: usize = module.vec_znx_big_normalize_tmp_bytes();
|
||||
|
||||
res_dft + (vmp | (res_small + normalize))
|
||||
}
|
||||
|
||||
pub fn external_product_inplace_scratch_space(module: &Module<FFT64>, res_size: usize, rhs: usize, rank: usize) -> usize {
|
||||
<GGSWCiphertext<Vec<u8>, FFT64> as VecGLWEProductScratchSpace>::prod_with_glwe_fourier_inplace_scratch_space(
|
||||
module, res_size, rhs, rank,
|
||||
)
|
||||
pub fn external_product_inplace_scratch_space(
|
||||
module: &Module<FFT64>,
|
||||
out_size: usize,
|
||||
ggsw_size: usize,
|
||||
rank: usize,
|
||||
) -> usize {
|
||||
Self::external_product_scratch_space(module, out_size, out_size, ggsw_size, rank)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -158,7 +174,61 @@ where
|
||||
VecZnxDft<DataLhs, FFT64>: VecZnxDftToRef<FFT64>,
|
||||
MatZnxDft<DataRhs, FFT64>: MatZnxDftToRef<FFT64>,
|
||||
{
|
||||
rhs.0.prod_with_glwe_fourier(module, self, lhs, scratch);
|
||||
let basek: usize = self.basek();
|
||||
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
assert_eq!(lhs.rank(), rhs.rank_in());
|
||||
assert_eq!(self.rank(), rhs.rank_out());
|
||||
assert_eq!(self.basek(), basek);
|
||||
assert_eq!(lhs.basek(), basek);
|
||||
assert_eq!(rhs.n(), module.n());
|
||||
assert_eq!(self.n(), module.n());
|
||||
assert_eq!(lhs.n(), module.n());
|
||||
assert!(
|
||||
scratch.available()
|
||||
>= GLWECiphertextFourier::keyswitch_scratch_space(
|
||||
module,
|
||||
self.size(),
|
||||
self.rank(),
|
||||
lhs.size(),
|
||||
lhs.rank(),
|
||||
rhs.size(),
|
||||
)
|
||||
);
|
||||
}
|
||||
|
||||
let cols_in: usize = rhs.rank_in();
|
||||
let cols_out: usize = rhs.rank_out() + 1;
|
||||
|
||||
// Buffer of the result of VMP in DFT
|
||||
let (mut res_dft, scratch1) = scratch.tmp_vec_znx_dft(module, cols_out, rhs.size()); // Todo optimise
|
||||
|
||||
{
|
||||
// Applies VMP
|
||||
let (mut ai_dft, scratch2) = scratch1.tmp_vec_znx_dft(module, cols_in, lhs.size());
|
||||
(0..cols_in).for_each(|col_i| {
|
||||
module.vec_znx_dft_copy(&mut ai_dft, col_i, lhs, col_i + 1);
|
||||
});
|
||||
module.vmp_apply(&mut res_dft, &ai_dft, rhs, scratch2);
|
||||
}
|
||||
|
||||
// Switches result of VMP outside of DFT
|
||||
let mut res_big: VecZnxBig<&mut [u8], FFT64> = module.vec_znx_idft_consume::<&mut [u8]>(res_dft);
|
||||
|
||||
{
|
||||
// Switches lhs 0-th outside of DFT domain and adds on
|
||||
let (mut a0_big, scratch2) = scratch1.tmp_vec_znx_big(module, 1, lhs.size());
|
||||
module.vec_znx_idft(&mut a0_big, 0, lhs, 0, scratch2);
|
||||
module.vec_znx_big_add_inplace(&mut res_big, 0, &a0_big, 0);
|
||||
}
|
||||
|
||||
// Space fr normalized VMP result outside of DFT domain
|
||||
let (mut res_small, scratch2) = scratch1.tmp_vec_znx(module, cols_out, lhs.size());
|
||||
(0..cols_out).for_each(|i| {
|
||||
module.vec_znx_big_normalize(basek, &mut res_small, i, &res_big, i, scratch2);
|
||||
module.vec_znx_dft(self, i, &res_small, i);
|
||||
});
|
||||
}
|
||||
|
||||
pub fn keyswitch_inplace<DataRhs>(
|
||||
@@ -169,7 +239,10 @@ where
|
||||
) where
|
||||
MatZnxDft<DataRhs, FFT64>: MatZnxDftToRef<FFT64>,
|
||||
{
|
||||
rhs.0.prod_with_glwe_fourier_inplace(module, self, scratch);
|
||||
unsafe {
|
||||
let self_ptr: *mut GLWECiphertextFourier<DataSelf, FFT64> = self as *mut GLWECiphertextFourier<DataSelf, FFT64>;
|
||||
self.keyswitch(&module, &*self_ptr, rhs, scratch);
|
||||
}
|
||||
}
|
||||
|
||||
pub fn external_product<DataLhs, DataRhs>(
|
||||
@@ -182,7 +255,37 @@ where
|
||||
VecZnxDft<DataLhs, FFT64>: VecZnxDftToRef<FFT64>,
|
||||
MatZnxDft<DataRhs, FFT64>: MatZnxDftToRef<FFT64>,
|
||||
{
|
||||
rhs.prod_with_glwe_fourier(module, self, lhs, scratch);
|
||||
let basek: usize = self.basek();
|
||||
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
assert_eq!(rhs.rank(), lhs.rank());
|
||||
assert_eq!(rhs.rank(), self.rank());
|
||||
assert_eq!(self.basek(), basek);
|
||||
assert_eq!(lhs.basek(), basek);
|
||||
assert_eq!(rhs.n(), module.n());
|
||||
assert_eq!(self.n(), module.n());
|
||||
assert_eq!(lhs.n(), module.n());
|
||||
}
|
||||
|
||||
let cols: usize = rhs.rank() + 1;
|
||||
|
||||
// Space for VMP result in DFT domain and high precision
|
||||
let (mut res_dft, scratch1) = scratch.tmp_vec_znx_dft(module, cols, rhs.size());
|
||||
|
||||
{
|
||||
module.vmp_apply(&mut res_dft, lhs, rhs, scratch1);
|
||||
}
|
||||
|
||||
// VMP result in high precision
|
||||
let res_big: VecZnxBig<&mut [u8], FFT64> = module.vec_znx_idft_consume::<&mut [u8]>(res_dft);
|
||||
|
||||
// Space for VMP result normalized
|
||||
let (mut res_small, scratch2) = scratch1.tmp_vec_znx(module, cols, rhs.size());
|
||||
(0..cols).for_each(|i| {
|
||||
module.vec_znx_big_normalize(basek, &mut res_small, i, &res_big, i, scratch2);
|
||||
module.vec_znx_dft(self, i, &res_small, i);
|
||||
});
|
||||
}
|
||||
|
||||
pub fn external_product_inplace<DataRhs>(
|
||||
@@ -193,7 +296,10 @@ where
|
||||
) where
|
||||
MatZnxDft<DataRhs, FFT64>: MatZnxDftToRef<FFT64>,
|
||||
{
|
||||
rhs.prod_with_glwe_fourier_inplace(module, self, scratch);
|
||||
unsafe {
|
||||
let self_ptr: *mut GLWECiphertextFourier<DataSelf, FFT64> = self as *mut GLWECiphertextFourier<DataSelf, FFT64>;
|
||||
self.external_product(&module, &*self_ptr, rhs, scratch);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -247,6 +353,7 @@ where
|
||||
pt.k = pt.k().min(self.k());
|
||||
}
|
||||
|
||||
#[allow(dead_code)]
|
||||
pub(crate) fn idft<DataRes>(&self, module: &Module<FFT64>, res: &mut GLWECiphertext<DataRes>, scratch: &mut Scratch)
|
||||
where
|
||||
GLWECiphertext<DataRes>: VecZnxToMut,
|
||||
|
||||
Reference in New Issue
Block a user