refactor of key-switching & external product

This commit is contained in:
Jean-Philippe Bossuat
2025-05-15 18:24:56 +02:00
parent 723a41acd0
commit ccd7450c5f
15 changed files with 1593 additions and 1740 deletions

View File

@@ -1,20 +1,13 @@
use base2k::{
Backend, FFT64, MatZnxDft, MatZnxDftToRef, Module, ScalarZnxDft, ScalarZnxDftOps, ScalarZnxDftToRef, Scratch, VecZnx,
VecZnxAlloc, VecZnxBig, VecZnxBigAlloc, VecZnxBigOps, VecZnxBigScratch, VecZnxDft, VecZnxDftAlloc, VecZnxDftOps,
VecZnxDftToMut, VecZnxDftToRef, VecZnxToMut, VecZnxToRef, ZnxZero,
Backend, FFT64, MatZnxDft, MatZnxDftOps, MatZnxDftScratch, MatZnxDftToRef, Module, ScalarZnxDft, ScalarZnxDftOps,
ScalarZnxDftToRef, Scratch, VecZnx, VecZnxAlloc, VecZnxBig, VecZnxBigAlloc, VecZnxBigOps, VecZnxBigScratch, VecZnxDft,
VecZnxDftAlloc, VecZnxDftOps, VecZnxDftToMut, VecZnxDftToRef, VecZnxToMut, VecZnxToRef, ZnxZero,
};
use sampling::source::Source;
use crate::{
elem::Infos,
gglwe_ciphertext::GGLWECiphertext,
ggsw_ciphertext::GGSWCiphertext,
glwe_ciphertext::GLWECiphertext,
glwe_plaintext::GLWEPlaintext,
keys::SecretKeyFourier,
keyswitch_key::GLWESwitchingKey,
utils::derive_size,
vec_glwe_product::{VecGLWEProduct, VecGLWEProductScratchSpace},
elem::Infos, ggsw_ciphertext::GGSWCiphertext, glwe_ciphertext::GLWECiphertext, glwe_plaintext::GLWEPlaintext,
keys::SecretKeyFourier, keyswitch_key::GLWESwitchingKey, utils::derive_size,
};
pub struct GLWECiphertextFourier<C, B: Backend> {
@@ -24,11 +17,11 @@ pub struct GLWECiphertextFourier<C, B: Backend> {
}
impl<B: Backend> GLWECiphertextFourier<Vec<u8>, B> {
pub fn new(module: &Module<B>, log_base2k: usize, log_k: usize, rank: usize) -> Self {
pub fn new(module: &Module<B>, basek: usize, k: usize, rank: usize) -> Self {
Self {
data: module.new_vec_znx_dft(rank + 1, derive_size(log_base2k, log_k)),
basek: log_base2k,
k: log_k,
data: module.new_vec_znx_dft(rank + 1, derive_size(basek, k)),
basek: basek,
k: k,
}
}
}
@@ -92,33 +85,56 @@ impl GLWECiphertextFourier<Vec<u8>, FFT64> {
pub fn keyswitch_scratch_space(
module: &Module<FFT64>,
res_size: usize,
lhs: usize,
rhs: usize,
rank_in: usize,
rank_out: usize,
out_size: usize,
out_rank: usize,
in_size: usize,
in_rank: usize,
ksk_size: usize,
) -> usize {
<GGLWECiphertext<Vec<u8>, FFT64> as VecGLWEProductScratchSpace>::prod_with_glwe_fourier_scratch_space(
module, res_size, lhs, rhs, rank_in, rank_out,
)
let res_dft: usize = module.bytes_of_vec_znx_dft(out_rank + 1, out_size);
let vmp = module.bytes_of_vec_znx_dft(in_rank, in_size)
+ module.vmp_apply_tmp_bytes(
out_size,
in_size,
in_size,
in_rank + 1,
out_rank + 1,
ksk_size,
);
let res_small: usize = module.bytes_of_vec_znx(out_rank + 1, out_size);
let add_a0: usize = module.bytes_of_vec_znx_big(1, in_size) + module.vec_znx_idft_tmp_bytes();
let normalize: usize = module.vec_znx_big_normalize_tmp_bytes();
res_dft + (vmp | add_a0 | (res_small + normalize))
}
pub fn keyswitch_inplace_scratch_space(module: &Module<FFT64>, res_size: usize, rhs: usize, rank: usize) -> usize {
<GGLWECiphertext<Vec<u8>, FFT64> as VecGLWEProductScratchSpace>::prod_with_glwe_fourier_inplace_scratch_space(
module, res_size, rhs, rank,
)
pub fn keyswitch_inplace_scratch_space(module: &Module<FFT64>, out_size: usize, out_rank: usize, ksk_size: usize) -> usize {
Self::keyswitch_scratch_space(module, out_size, out_rank, out_size, out_rank, ksk_size)
}
pub fn external_product_scratch_space(module: &Module<FFT64>, res_size: usize, lhs: usize, rhs: usize, rank: usize) -> usize {
<GGSWCiphertext<Vec<u8>, FFT64> as VecGLWEProductScratchSpace>::prod_with_glwe_fourier_scratch_space(
module, res_size, lhs, rhs, rank, rank,
)
pub fn external_product_scratch_space(
module: &Module<FFT64>,
out_size: usize,
in_size: usize,
ggsw_size: usize,
rank: usize,
) -> usize {
let res_dft: usize = module.bytes_of_vec_znx_dft(rank + 1, out_size);
let vmp: usize = module.vmp_apply_tmp_bytes(out_size, in_size, in_size, rank + 1, rank + 1, ggsw_size);
let res_small: usize = module.bytes_of_vec_znx(rank + 1, out_size);
let normalize: usize = module.vec_znx_big_normalize_tmp_bytes();
res_dft + (vmp | (res_small + normalize))
}
pub fn external_product_inplace_scratch_space(module: &Module<FFT64>, res_size: usize, rhs: usize, rank: usize) -> usize {
<GGSWCiphertext<Vec<u8>, FFT64> as VecGLWEProductScratchSpace>::prod_with_glwe_fourier_inplace_scratch_space(
module, res_size, rhs, rank,
)
pub fn external_product_inplace_scratch_space(
module: &Module<FFT64>,
out_size: usize,
ggsw_size: usize,
rank: usize,
) -> usize {
Self::external_product_scratch_space(module, out_size, out_size, ggsw_size, rank)
}
}
@@ -158,7 +174,61 @@ where
VecZnxDft<DataLhs, FFT64>: VecZnxDftToRef<FFT64>,
MatZnxDft<DataRhs, FFT64>: MatZnxDftToRef<FFT64>,
{
rhs.0.prod_with_glwe_fourier(module, self, lhs, scratch);
let basek: usize = self.basek();
#[cfg(debug_assertions)]
{
assert_eq!(lhs.rank(), rhs.rank_in());
assert_eq!(self.rank(), rhs.rank_out());
assert_eq!(self.basek(), basek);
assert_eq!(lhs.basek(), basek);
assert_eq!(rhs.n(), module.n());
assert_eq!(self.n(), module.n());
assert_eq!(lhs.n(), module.n());
assert!(
scratch.available()
>= GLWECiphertextFourier::keyswitch_scratch_space(
module,
self.size(),
self.rank(),
lhs.size(),
lhs.rank(),
rhs.size(),
)
);
}
let cols_in: usize = rhs.rank_in();
let cols_out: usize = rhs.rank_out() + 1;
// Buffer of the result of VMP in DFT
let (mut res_dft, scratch1) = scratch.tmp_vec_znx_dft(module, cols_out, rhs.size()); // Todo optimise
{
// Applies VMP
let (mut ai_dft, scratch2) = scratch1.tmp_vec_znx_dft(module, cols_in, lhs.size());
(0..cols_in).for_each(|col_i| {
module.vec_znx_dft_copy(&mut ai_dft, col_i, lhs, col_i + 1);
});
module.vmp_apply(&mut res_dft, &ai_dft, rhs, scratch2);
}
// Switches result of VMP outside of DFT
let mut res_big: VecZnxBig<&mut [u8], FFT64> = module.vec_znx_idft_consume::<&mut [u8]>(res_dft);
{
// Switches lhs 0-th outside of DFT domain and adds on
let (mut a0_big, scratch2) = scratch1.tmp_vec_znx_big(module, 1, lhs.size());
module.vec_znx_idft(&mut a0_big, 0, lhs, 0, scratch2);
module.vec_znx_big_add_inplace(&mut res_big, 0, &a0_big, 0);
}
// Space fr normalized VMP result outside of DFT domain
let (mut res_small, scratch2) = scratch1.tmp_vec_znx(module, cols_out, lhs.size());
(0..cols_out).for_each(|i| {
module.vec_znx_big_normalize(basek, &mut res_small, i, &res_big, i, scratch2);
module.vec_znx_dft(self, i, &res_small, i);
});
}
pub fn keyswitch_inplace<DataRhs>(
@@ -169,7 +239,10 @@ where
) where
MatZnxDft<DataRhs, FFT64>: MatZnxDftToRef<FFT64>,
{
rhs.0.prod_with_glwe_fourier_inplace(module, self, scratch);
unsafe {
let self_ptr: *mut GLWECiphertextFourier<DataSelf, FFT64> = self as *mut GLWECiphertextFourier<DataSelf, FFT64>;
self.keyswitch(&module, &*self_ptr, rhs, scratch);
}
}
pub fn external_product<DataLhs, DataRhs>(
@@ -182,7 +255,37 @@ where
VecZnxDft<DataLhs, FFT64>: VecZnxDftToRef<FFT64>,
MatZnxDft<DataRhs, FFT64>: MatZnxDftToRef<FFT64>,
{
rhs.prod_with_glwe_fourier(module, self, lhs, scratch);
let basek: usize = self.basek();
#[cfg(debug_assertions)]
{
assert_eq!(rhs.rank(), lhs.rank());
assert_eq!(rhs.rank(), self.rank());
assert_eq!(self.basek(), basek);
assert_eq!(lhs.basek(), basek);
assert_eq!(rhs.n(), module.n());
assert_eq!(self.n(), module.n());
assert_eq!(lhs.n(), module.n());
}
let cols: usize = rhs.rank() + 1;
// Space for VMP result in DFT domain and high precision
let (mut res_dft, scratch1) = scratch.tmp_vec_znx_dft(module, cols, rhs.size());
{
module.vmp_apply(&mut res_dft, lhs, rhs, scratch1);
}
// VMP result in high precision
let res_big: VecZnxBig<&mut [u8], FFT64> = module.vec_znx_idft_consume::<&mut [u8]>(res_dft);
// Space for VMP result normalized
let (mut res_small, scratch2) = scratch1.tmp_vec_znx(module, cols, rhs.size());
(0..cols).for_each(|i| {
module.vec_znx_big_normalize(basek, &mut res_small, i, &res_big, i, scratch2);
module.vec_znx_dft(self, i, &res_small, i);
});
}
pub fn external_product_inplace<DataRhs>(
@@ -193,7 +296,10 @@ where
) where
MatZnxDft<DataRhs, FFT64>: MatZnxDftToRef<FFT64>,
{
rhs.prod_with_glwe_fourier_inplace(module, self, scratch);
unsafe {
let self_ptr: *mut GLWECiphertextFourier<DataSelf, FFT64> = self as *mut GLWECiphertextFourier<DataSelf, FFT64>;
self.external_product(&module, &*self_ptr, rhs, scratch);
}
}
}
@@ -247,6 +353,7 @@ where
pt.k = pt.k().min(self.k());
}
#[allow(dead_code)]
pub(crate) fn idft<DataRes>(&self, module: &Module<FFT64>, res: &mut GLWECiphertext<DataRes>, scratch: &mut Scratch)
where
GLWECiphertext<DataRes>: VecZnxToMut,