Added tensor key & associated test

This commit is contained in:
Jean-Philippe Bossuat
2025-05-19 18:06:14 +02:00
parent c5fe07188f
commit 8f2eac4928
12 changed files with 610 additions and 28 deletions

View File

@@ -6,6 +6,7 @@ use base2k::{
use sampling::source::Source;
use crate::{
automorphism::AutomorphismKey,
elem::{GetRow, Infos, SetRow},
glwe_ciphertext::GLWECiphertext,
glwe_ciphertext_fourier::GLWECiphertextFourier,
@@ -78,6 +79,20 @@ impl GGSWCiphertext<Vec<u8>, FFT64> {
+ module.bytes_of_vec_znx_dft(rank + 1, size)
}
pub fn automorphism_scratch_space(
module: &Module<FFT64>,
out_size: usize,
in_size: usize,
auto_key_size: usize,
rank: usize,
) -> usize {
let size: usize = in_size.min(out_size);
let tmp_dft: usize = module.bytes_of_vec_znx_dft(rank + 1, size);
let tmp_idft: usize = module.bytes_of_vec_znx(rank + 1, size);
let vmp: usize = GLWECiphertext::keyswitch_from_fourier_scratch_space(module, size, rank, size, rank, auto_key_size);
tmp_dft + tmp_idft + vmp
}
pub fn external_product_scratch_space(
module: &Module<FFT64>,
out_size: usize,
@@ -182,6 +197,73 @@ where
});
}
pub fn automorphism<DataLhs, DataRhs>(
&mut self,
module: &Module<FFT64>,
lhs: &GGSWCiphertext<DataLhs, FFT64>,
rhs: &AutomorphismKey<DataRhs, FFT64>,
scratch: &mut Scratch,
) where
MatZnxDft<DataLhs, FFT64>: MatZnxDftToRef<FFT64>,
MatZnxDft<DataRhs, FFT64>: MatZnxDftToRef<FFT64>,
{
#[cfg(debug_assertions)]
{
assert_eq!(
self.rank(),
lhs.rank(),
"ggsw_out rank: {} != ggsw_in rank: {}",
self.rank(),
lhs.rank()
);
assert_eq!(
self.rank(),
rhs.rank(),
"ggsw_in rank: {} != auto_key rank: {}",
self.rank(),
rhs.rank()
);
}
let size: usize = self.size().min(lhs.size());
let cols: usize = self.rank() + 1;
let (tmp_dft_data, scratch1) = scratch.tmp_vec_znx_dft(module, cols, size);
let mut tmp_dft: GLWECiphertextFourier<&mut [u8], FFT64> = GLWECiphertextFourier::<&mut [u8], FFT64> {
data: tmp_dft_data,
basek: lhs.basek(),
k: lhs.k(),
};
let (tmp_idft_data, scratch2) = scratch1.tmp_vec_znx(module, cols, size);
let mut tmp_idft: GLWECiphertext<&mut [u8]> = GLWECiphertext::<&mut [u8]> {
data: tmp_idft_data,
basek: self.basek(),
k: self.k(),
};
(0..cols).for_each(|col_i| {
(0..self.rows()).for_each(|row_j| {
lhs.get_row(module, row_j, col_i, &mut tmp_dft);
tmp_idft.keyswitch_from_fourier(module, &tmp_dft, &rhs.key, scratch2);
(0..cols).for_each(|i| {
module.vec_znx_automorphism_inplace(rhs.p(), &mut tmp_idft, i);
});
self.set_row(module, row_j, col_i, &tmp_dft);
});
});
tmp_dft.data.zero();
(self.rows().min(lhs.rows())..self.rows()).for_each(|row_i| {
(0..self.rank() + 1).for_each(|col_j| {
self.set_row(module, row_i, col_j, &tmp_dft);
});
});
}
pub fn external_product<DataLhs, DataRhs>(
&mut self,
module: &Module<FFT64>,