mirror of
https://github.com/arnaucube/poulpy.git
synced 2026-02-10 13:16:44 +01:00
refactor of key-switching & external product
This commit is contained in:
@@ -1,6 +1,6 @@
|
||||
use base2k::{
|
||||
Backend, FFT64, MatZnxDft, MatZnxDftOps, MatZnxDftToMut, MatZnxDftToRef, Module, ScalarZnx, ScalarZnxDft, ScalarZnxDftToRef,
|
||||
ScalarZnxToRef, Scratch, VecZnxDft, VecZnxDftToMut, VecZnxDftToRef,
|
||||
ScalarZnxToRef, Scratch, VecZnxDft, VecZnxDftAlloc, VecZnxDftToMut, VecZnxDftToRef, ZnxZero,
|
||||
};
|
||||
use sampling::source::Source;
|
||||
|
||||
@@ -10,7 +10,6 @@ use crate::{
|
||||
ggsw_ciphertext::GGSWCiphertext,
|
||||
glwe_ciphertext_fourier::GLWECiphertextFourier,
|
||||
keys::{SecretKey, SecretKeyFourier},
|
||||
vec_glwe_product::{VecGLWEProduct, VecGLWEProductScratchSpace},
|
||||
};
|
||||
|
||||
pub struct GLWESwitchingKey<Data, B: Backend>(pub(crate) GGLWECiphertext<Data, B>);
|
||||
@@ -39,6 +38,20 @@ impl<T, B: Backend> Infos for GLWESwitchingKey<T, B> {
|
||||
}
|
||||
}
|
||||
|
||||
impl<T, B: Backend> GLWESwitchingKey<T, B> {
|
||||
pub fn rank(&self) -> usize {
|
||||
self.0.data.cols_out() - 1
|
||||
}
|
||||
|
||||
pub fn rank_in(&self) -> usize {
|
||||
self.0.data.cols_in()
|
||||
}
|
||||
|
||||
pub fn rank_out(&self) -> usize {
|
||||
self.0.data.cols_out() - 1
|
||||
}
|
||||
}
|
||||
|
||||
impl<DataSelf, B: Backend> MatZnxDftToMut<B> for GLWESwitchingKey<DataSelf, B>
|
||||
where
|
||||
MatZnxDft<DataSelf, B>: MatZnxDftToMut<B>,
|
||||
@@ -131,33 +144,46 @@ where
|
||||
impl GLWESwitchingKey<Vec<u8>, FFT64> {
|
||||
pub fn keyswitch_scratch_space(
|
||||
module: &Module<FFT64>,
|
||||
res_size: usize,
|
||||
lhs: usize,
|
||||
rhs: usize,
|
||||
rank_in: usize,
|
||||
rank_out: usize,
|
||||
out_size: usize,
|
||||
out_rank: usize,
|
||||
in_size: usize,
|
||||
in_rank: usize,
|
||||
ksk_size: usize,
|
||||
) -> usize {
|
||||
<GGLWECiphertext<Vec<u8>, FFT64> as VecGLWEProductScratchSpace>::prod_with_vec_glwe_scratch_space(
|
||||
module, res_size, lhs, rhs, rank_in, rank_out,
|
||||
)
|
||||
let tmp_in: usize = module.bytes_of_vec_znx_dft(in_rank + 1, in_size);
|
||||
let tmp_out: usize = module.bytes_of_vec_znx_dft(out_rank + 1, out_size);
|
||||
let ksk: usize = GLWECiphertextFourier::keyswitch_scratch_space(module, out_size, out_rank, in_size, in_rank, ksk_size);
|
||||
tmp_in + tmp_out + ksk
|
||||
}
|
||||
|
||||
pub fn keyswitch_inplace_scratch_space(module: &Module<FFT64>, res_size: usize, rhs: usize, rank: usize) -> usize {
|
||||
<GGLWECiphertext<Vec<u8>, FFT64> as VecGLWEProductScratchSpace>::prod_with_vec_glwe_inplace_scratch_space(
|
||||
module, res_size, rhs, rank,
|
||||
)
|
||||
pub fn keyswitch_inplace_scratch_space(module: &Module<FFT64>, out_size: usize, out_rank: usize, ksk_size: usize) -> usize {
|
||||
let tmp: usize = module.bytes_of_vec_znx_dft(out_rank + 1, out_size);
|
||||
let ksk: usize = GLWECiphertextFourier::keyswitch_inplace_scratch_space(module, out_size, out_rank, ksk_size);
|
||||
tmp + ksk
|
||||
}
|
||||
|
||||
pub fn external_product_scratch_space(module: &Module<FFT64>, res_size: usize, lhs: usize, rhs: usize, rank: usize) -> usize {
|
||||
<GGSWCiphertext<Vec<u8>, FFT64> as VecGLWEProductScratchSpace>::prod_with_vec_glwe_scratch_space(
|
||||
module, res_size, lhs, rhs, rank, rank,
|
||||
)
|
||||
pub fn external_product_scratch_space(
|
||||
module: &Module<FFT64>,
|
||||
out_size: usize,
|
||||
in_size: usize,
|
||||
ggsw_size: usize,
|
||||
rank: usize,
|
||||
) -> usize {
|
||||
let tmp_in: usize = module.bytes_of_vec_znx_dft(rank + 1, in_size);
|
||||
let tmp_out: usize = module.bytes_of_vec_znx_dft(rank + 1, out_size);
|
||||
let ggsw: usize = GLWECiphertextFourier::external_product_scratch_space(module, out_size, in_size, ggsw_size, rank);
|
||||
tmp_in + tmp_out + ggsw
|
||||
}
|
||||
|
||||
pub fn external_product_inplace_scratch_space(module: &Module<FFT64>, res_size: usize, rhs: usize, rank: usize) -> usize {
|
||||
<GGSWCiphertext<Vec<u8>, FFT64> as VecGLWEProductScratchSpace>::prod_with_glwe_inplace_scratch_space(
|
||||
module, res_size, rhs, rank,
|
||||
)
|
||||
pub fn external_product_inplace_scratch_space(
|
||||
module: &Module<FFT64>,
|
||||
out_size: usize,
|
||||
ggsw_size: usize,
|
||||
rank: usize,
|
||||
) -> usize {
|
||||
let tmp: usize = module.bytes_of_vec_znx_dft(rank + 1, out_size);
|
||||
let ggsw: usize = GLWECiphertextFourier::external_product_inplace_scratch_space(module, out_size, ggsw_size, rank);
|
||||
tmp + ggsw
|
||||
}
|
||||
}
|
||||
|
||||
@@ -175,8 +201,62 @@ where
|
||||
MatZnxDft<DataLhs, FFT64>: MatZnxDftToRef<FFT64>,
|
||||
MatZnxDft<DataRhs, FFT64>: MatZnxDftToRef<FFT64>,
|
||||
{
|
||||
rhs.0
|
||||
.prod_with_vec_glwe(module, &mut self.0, &lhs.0, scratch);
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
assert_eq!(
|
||||
self.rank_in(),
|
||||
lhs.rank_in(),
|
||||
"ksk_out input rank: {} != ksk_in input rank: {}",
|
||||
self.rank_in(),
|
||||
lhs.rank_in()
|
||||
);
|
||||
assert_eq!(
|
||||
lhs.rank_out(),
|
||||
rhs.rank_in(),
|
||||
"ksk_in output rank: {} != ksk_apply input rank: {}",
|
||||
self.rank_out(),
|
||||
rhs.rank_in()
|
||||
);
|
||||
assert_eq!(
|
||||
self.rank_out(),
|
||||
rhs.rank_out(),
|
||||
"ksk_out output rank: {} != ksk_apply output rank: {}",
|
||||
self.rank_out(),
|
||||
rhs.rank_out()
|
||||
);
|
||||
}
|
||||
|
||||
let (tmp_in_data, scratch1) = scratch.tmp_vec_znx_dft(module, lhs.rank_out() + 1, lhs.size());
|
||||
|
||||
let mut tmp_in: GLWECiphertextFourier<&mut [u8], FFT64> = GLWECiphertextFourier::<&mut [u8], FFT64> {
|
||||
data: tmp_in_data,
|
||||
basek: lhs.basek(),
|
||||
k: lhs.k(),
|
||||
};
|
||||
|
||||
let (tmp_out_data, scratch2) = scratch1.tmp_vec_znx_dft(module, self.rank_out() + 1, self.size());
|
||||
|
||||
let mut tmp_out: GLWECiphertextFourier<&mut [u8], FFT64> = GLWECiphertextFourier::<&mut [u8], FFT64> {
|
||||
data: tmp_out_data,
|
||||
basek: self.basek(),
|
||||
k: self.k(),
|
||||
};
|
||||
|
||||
(0..self.rank_in()).for_each(|col_i| {
|
||||
(0..self.rows()).for_each(|row_j| {
|
||||
lhs.get_row(module, row_j, col_i, &mut tmp_in);
|
||||
tmp_out.keyswitch(module, &tmp_in, rhs, scratch2);
|
||||
self.set_row(module, row_j, col_i, &tmp_out);
|
||||
});
|
||||
});
|
||||
|
||||
tmp_out.data.zero();
|
||||
|
||||
(self.rows().min(lhs.rows())..self.rows()).for_each(|row_i| {
|
||||
(0..self.rank_in()).for_each(|col_j| {
|
||||
self.set_row(module, row_i, col_j, &tmp_out);
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
pub fn keyswitch_inplace<DataRhs>(
|
||||
@@ -187,8 +267,32 @@ where
|
||||
) where
|
||||
MatZnxDft<DataRhs, FFT64>: MatZnxDftToRef<FFT64>,
|
||||
{
|
||||
rhs.0
|
||||
.prod_with_vec_glwe_inplace(module, &mut self.0, scratch);
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
assert_eq!(
|
||||
self.rank_out(),
|
||||
rhs.rank_out(),
|
||||
"ksk_out output rank: {} != ksk_apply output rank: {}",
|
||||
self.rank_out(),
|
||||
rhs.rank_out()
|
||||
);
|
||||
}
|
||||
|
||||
let (tmp_data, scratch1) = scratch.tmp_vec_znx_dft(module, self.rank_out() + 1, self.size());
|
||||
|
||||
let mut tmp: GLWECiphertextFourier<&mut [u8], FFT64> = GLWECiphertextFourier::<&mut [u8], FFT64> {
|
||||
data: tmp_data,
|
||||
basek: self.basek(),
|
||||
k: self.k(),
|
||||
};
|
||||
|
||||
(0..self.rank_in()).for_each(|col_i| {
|
||||
(0..self.rows()).for_each(|row_j| {
|
||||
self.get_row(module, row_j, col_i, &mut tmp);
|
||||
tmp.keyswitch_inplace(module, rhs, scratch1);
|
||||
self.set_row(module, row_j, col_i, &tmp);
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
pub fn external_product<DataLhs, DataRhs>(
|
||||
@@ -201,7 +305,62 @@ where
|
||||
MatZnxDft<DataLhs, FFT64>: MatZnxDftToRef<FFT64>,
|
||||
MatZnxDft<DataRhs, FFT64>: MatZnxDftToRef<FFT64>,
|
||||
{
|
||||
rhs.prod_with_vec_glwe(module, &mut self.0, &lhs.0, scratch);
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
assert_eq!(
|
||||
self.rank_in(),
|
||||
lhs.rank_in(),
|
||||
"ksk_out input rank: {} != ksk_in input rank: {}",
|
||||
self.rank_in(),
|
||||
lhs.rank_in()
|
||||
);
|
||||
assert_eq!(
|
||||
lhs.rank_out(),
|
||||
rhs.rank(),
|
||||
"ksk_in output rank: {} != ggsw rank: {}",
|
||||
self.rank_out(),
|
||||
rhs.rank()
|
||||
);
|
||||
assert_eq!(
|
||||
self.rank_out(),
|
||||
rhs.rank(),
|
||||
"ksk_out output rank: {} != ggsw rank: {}",
|
||||
self.rank_out(),
|
||||
rhs.rank()
|
||||
);
|
||||
}
|
||||
|
||||
let (tmp_in_data, scratch1) = scratch.tmp_vec_znx_dft(module, lhs.rank_out() + 1, lhs.size());
|
||||
|
||||
let mut tmp_in: GLWECiphertextFourier<&mut [u8], FFT64> = GLWECiphertextFourier::<&mut [u8], FFT64> {
|
||||
data: tmp_in_data,
|
||||
basek: lhs.basek(),
|
||||
k: lhs.k(),
|
||||
};
|
||||
|
||||
let (tmp_out_data, scratch2) = scratch1.tmp_vec_znx_dft(module, self.rank_out() + 1, self.size());
|
||||
|
||||
let mut tmp_out: GLWECiphertextFourier<&mut [u8], FFT64> = GLWECiphertextFourier::<&mut [u8], FFT64> {
|
||||
data: tmp_out_data,
|
||||
basek: self.basek(),
|
||||
k: self.k(),
|
||||
};
|
||||
|
||||
(0..self.rank_in()).for_each(|col_i| {
|
||||
(0..self.rows()).for_each(|row_j| {
|
||||
lhs.get_row(module, row_j, col_i, &mut tmp_in);
|
||||
tmp_out.external_product(module, &tmp_in, rhs, scratch2);
|
||||
self.set_row(module, row_j, col_i, &tmp_out);
|
||||
});
|
||||
});
|
||||
|
||||
tmp_out.data.zero();
|
||||
|
||||
(self.rows().min(lhs.rows())..self.rows()).for_each(|row_i| {
|
||||
(0..self.rank_in()).for_each(|col_j| {
|
||||
self.set_row(module, row_i, col_j, &tmp_out);
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
pub fn external_product_inplace<DataRhs>(
|
||||
@@ -212,6 +371,31 @@ where
|
||||
) where
|
||||
MatZnxDft<DataRhs, FFT64>: MatZnxDftToRef<FFT64>,
|
||||
{
|
||||
rhs.prod_with_vec_glwe_inplace(module, &mut self.0, scratch);
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
assert_eq!(
|
||||
self.rank_out(),
|
||||
rhs.rank(),
|
||||
"ksk_out output rank: {} != ggsw rank: {}",
|
||||
self.rank_out(),
|
||||
rhs.rank()
|
||||
);
|
||||
}
|
||||
|
||||
let (tmp_data, scratch1) = scratch.tmp_vec_znx_dft(module, self.rank_out() + 1, self.size());
|
||||
|
||||
let mut tmp: GLWECiphertextFourier<&mut [u8], FFT64> = GLWECiphertextFourier::<&mut [u8], FFT64> {
|
||||
data: tmp_data,
|
||||
basek: self.basek(),
|
||||
k: self.k(),
|
||||
};
|
||||
|
||||
(0..self.rank_in()).for_each(|col_i| {
|
||||
(0..self.rows()).for_each(|row_j| {
|
||||
self.get_row(module, row_j, col_i, &mut tmp);
|
||||
tmp.external_product_inplace(module, rhs, scratch1);
|
||||
self.set_row(module, row_j, col_i, &tmp);
|
||||
});
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user