mirror of
https://github.com/arnaucube/poulpy.git
synced 2026-02-10 05:06:44 +01:00
Added tensor key & associated test
This commit is contained in:
@@ -6,6 +6,7 @@ use base2k::{
|
||||
use sampling::source::Source;
|
||||
|
||||
use crate::{
|
||||
automorphism::AutomorphismKey,
|
||||
elem::{GetRow, Infos, SetRow},
|
||||
glwe_ciphertext::GLWECiphertext,
|
||||
glwe_ciphertext_fourier::GLWECiphertextFourier,
|
||||
@@ -78,6 +79,20 @@ impl GGSWCiphertext<Vec<u8>, FFT64> {
|
||||
+ module.bytes_of_vec_znx_dft(rank + 1, size)
|
||||
}
|
||||
|
||||
pub fn automorphism_scratch_space(
|
||||
module: &Module<FFT64>,
|
||||
out_size: usize,
|
||||
in_size: usize,
|
||||
auto_key_size: usize,
|
||||
rank: usize,
|
||||
) -> usize {
|
||||
let size: usize = in_size.min(out_size);
|
||||
let tmp_dft: usize = module.bytes_of_vec_znx_dft(rank + 1, size);
|
||||
let tmp_idft: usize = module.bytes_of_vec_znx(rank + 1, size);
|
||||
let vmp: usize = GLWECiphertext::keyswitch_from_fourier_scratch_space(module, size, rank, size, rank, auto_key_size);
|
||||
tmp_dft + tmp_idft + vmp
|
||||
}
|
||||
|
||||
pub fn external_product_scratch_space(
|
||||
module: &Module<FFT64>,
|
||||
out_size: usize,
|
||||
@@ -182,6 +197,73 @@ where
|
||||
});
|
||||
}
|
||||
|
||||
pub fn automorphism<DataLhs, DataRhs>(
|
||||
&mut self,
|
||||
module: &Module<FFT64>,
|
||||
lhs: &GGSWCiphertext<DataLhs, FFT64>,
|
||||
rhs: &AutomorphismKey<DataRhs, FFT64>,
|
||||
scratch: &mut Scratch,
|
||||
) where
|
||||
MatZnxDft<DataLhs, FFT64>: MatZnxDftToRef<FFT64>,
|
||||
MatZnxDft<DataRhs, FFT64>: MatZnxDftToRef<FFT64>,
|
||||
{
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
assert_eq!(
|
||||
self.rank(),
|
||||
lhs.rank(),
|
||||
"ggsw_out rank: {} != ggsw_in rank: {}",
|
||||
self.rank(),
|
||||
lhs.rank()
|
||||
);
|
||||
assert_eq!(
|
||||
self.rank(),
|
||||
rhs.rank(),
|
||||
"ggsw_in rank: {} != auto_key rank: {}",
|
||||
self.rank(),
|
||||
rhs.rank()
|
||||
);
|
||||
}
|
||||
|
||||
let size: usize = self.size().min(lhs.size());
|
||||
let cols: usize = self.rank() + 1;
|
||||
|
||||
let (tmp_dft_data, scratch1) = scratch.tmp_vec_znx_dft(module, cols, size);
|
||||
|
||||
let mut tmp_dft: GLWECiphertextFourier<&mut [u8], FFT64> = GLWECiphertextFourier::<&mut [u8], FFT64> {
|
||||
data: tmp_dft_data,
|
||||
basek: lhs.basek(),
|
||||
k: lhs.k(),
|
||||
};
|
||||
|
||||
let (tmp_idft_data, scratch2) = scratch1.tmp_vec_znx(module, cols, size);
|
||||
|
||||
let mut tmp_idft: GLWECiphertext<&mut [u8]> = GLWECiphertext::<&mut [u8]> {
|
||||
data: tmp_idft_data,
|
||||
basek: self.basek(),
|
||||
k: self.k(),
|
||||
};
|
||||
|
||||
(0..cols).for_each(|col_i| {
|
||||
(0..self.rows()).for_each(|row_j| {
|
||||
lhs.get_row(module, row_j, col_i, &mut tmp_dft);
|
||||
tmp_idft.keyswitch_from_fourier(module, &tmp_dft, &rhs.key, scratch2);
|
||||
(0..cols).for_each(|i| {
|
||||
module.vec_znx_automorphism_inplace(rhs.p(), &mut tmp_idft, i);
|
||||
});
|
||||
self.set_row(module, row_j, col_i, &tmp_dft);
|
||||
});
|
||||
});
|
||||
|
||||
tmp_dft.data.zero();
|
||||
|
||||
(self.rows().min(lhs.rows())..self.rows()).for_each(|row_i| {
|
||||
(0..self.rank() + 1).for_each(|col_j| {
|
||||
self.set_row(module, row_i, col_j, &tmp_dft);
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
pub fn external_product<DataLhs, DataRhs>(
|
||||
&mut self,
|
||||
module: &Module<FFT64>,
|
||||
|
||||
Reference in New Issue
Block a user