refactor of key-switching & external product

This commit is contained in:
Jean-Philippe Bossuat
2025-05-15 18:24:56 +02:00
parent 723a41acd0
commit ccd7450c5f
15 changed files with 1593 additions and 1740 deletions

View File

@@ -1,22 +1,20 @@
use base2k::{
AddNormal, Backend, FFT64, FillUniform, MatZnxDft, MatZnxDftToRef, Module, ScalarZnxAlloc, ScalarZnxDft, ScalarZnxDftAlloc,
ScalarZnxDftOps, ScalarZnxDftToRef, Scratch, VecZnx, VecZnxAlloc, VecZnxBig, VecZnxBigAlloc, VecZnxBigOps, VecZnxBigScratch,
VecZnxDft, VecZnxDftAlloc, VecZnxDftOps, VecZnxDftToMut, VecZnxDftToRef, VecZnxOps, VecZnxToMut, VecZnxToRef, ZnxInfos,
ZnxZero,
AddNormal, Backend, FFT64, FillUniform, MatZnxDft, MatZnxDftOps, MatZnxDftScratch, MatZnxDftToRef, Module, ScalarZnxAlloc,
ScalarZnxDft, ScalarZnxDftAlloc, ScalarZnxDftOps, ScalarZnxDftToRef, Scratch, VecZnx, VecZnxAlloc, VecZnxBig, VecZnxBigAlloc,
VecZnxBigOps, VecZnxBigScratch, VecZnxDft, VecZnxDftAlloc, VecZnxDftOps, VecZnxDftToMut, VecZnxDftToRef, VecZnxOps,
VecZnxToMut, VecZnxToRef, ZnxInfos, ZnxZero,
};
use sampling::source::Source;
use crate::{
SIX_SIGMA,
elem::Infos,
gglwe_ciphertext::GGLWECiphertext,
ggsw_ciphertext::GGSWCiphertext,
glwe_ciphertext_fourier::GLWECiphertextFourier,
glwe_plaintext::GLWEPlaintext,
keys::{GLWEPublicKey, SecretDistribution, SecretKeyFourier},
keyswitch_key::GLWESwitchingKey,
utils::derive_size,
vec_glwe_product::{VecGLWEProduct, VecGLWEProductScratchSpace},
};
pub struct GLWECiphertext<C> {
@@ -115,33 +113,50 @@ impl GLWECiphertext<Vec<u8>> {
pub fn keyswitch_scratch_space(
module: &Module<FFT64>,
res_size: usize,
lhs: usize,
rhs: usize,
rank_in: usize,
rank_out: usize,
out_size: usize,
out_rank: usize,
in_size: usize,
in_rank: usize,
ksk_size: usize,
) -> usize {
<GGLWECiphertext<Vec<u8>, FFT64> as VecGLWEProductScratchSpace>::prod_with_glwe_scratch_space(
module, res_size, lhs, rhs, rank_in, rank_out,
)
module.bytes_of_vec_znx_dft(out_rank + 1, ksk_size)
+ (module.vec_znx_big_normalize_tmp_bytes()
| (module.vmp_apply_tmp_bytes(
out_size,
in_size,
in_size,
in_rank + 1,
out_rank + 1,
ksk_size,
) + module.bytes_of_vec_znx_dft(in_size, in_size)))
}
pub fn keyswitch_inplace_scratch_space(module: &Module<FFT64>, res_size: usize, rhs: usize, rank: usize) -> usize {
<GGLWECiphertext<Vec<u8>, FFT64> as VecGLWEProductScratchSpace>::prod_with_glwe_inplace_scratch_space(
module, res_size, rhs, rank,
)
pub fn keyswitch_inplace_scratch_space(module: &Module<FFT64>, out_size: usize, out_rank: usize, ksk_size: usize) -> usize {
GLWECiphertext::keyswitch_scratch_space(module, out_size, out_rank, out_size, out_rank, ksk_size)
}
pub fn external_product_scratch_space(module: &Module<FFT64>, res_size: usize, lhs: usize, rhs: usize, rank: usize) -> usize {
<GGSWCiphertext<Vec<u8>, FFT64> as VecGLWEProductScratchSpace>::prod_with_glwe_scratch_space(
module, res_size, lhs, rhs, rank, rank,
)
pub fn external_product_scratch_space(
module: &Module<FFT64>,
out_size: usize,
in_size: usize,
ggsw_size: usize,
rank: usize,
) -> usize {
module.bytes_of_vec_znx_dft(rank + 1, ggsw_size)
+ ((module.bytes_of_vec_znx_dft(rank + 1, in_size)
+ module.vmp_apply_tmp_bytes(
out_size,
in_size,
in_size, // rows
rank + 1, // cols in
rank + 1, // cols out
ggsw_size,
))
| module.vec_znx_big_normalize_tmp_bytes())
}
pub fn external_product_inplace_scratch_space(module: &Module<FFT64>, res_size: usize, rhs: usize, rank: usize) -> usize {
<GGSWCiphertext<Vec<u8>, FFT64> as VecGLWEProductScratchSpace>::prod_with_glwe_inplace_scratch_space(
module, res_size, rhs, rank,
)
GLWECiphertext::external_product_scratch_space(module, res_size, res_size, rhs, rank)
}
}
@@ -235,7 +250,50 @@ where
VecZnx<DataLhs>: VecZnxToRef,
MatZnxDft<DataRhs, FFT64>: MatZnxDftToRef<FFT64>,
{
rhs.0.prod_with_glwe(module, self, lhs, scratch);
let basek: usize = self.basek();
#[cfg(debug_assertions)]
{
assert_eq!(lhs.rank(), rhs.rank_in());
assert_eq!(self.rank(), rhs.rank_out());
assert_eq!(self.basek(), basek);
assert_eq!(lhs.basek(), basek);
assert_eq!(rhs.n(), module.n());
assert_eq!(self.n(), module.n());
assert_eq!(lhs.n(), module.n());
assert!(
scratch.available()
>= GLWECiphertext::keyswitch_scratch_space(
module,
self.size(),
self.rank(),
lhs.size(),
lhs.rank(),
rhs.size(),
)
);
}
let cols_in: usize = rhs.rank_in();
let cols_out: usize = rhs.rank_out() + 1;
let (mut res_dft, scratch1) = scratch.tmp_vec_znx_dft(module, cols_out, rhs.size()); // Todo optimise
{
let (mut ai_dft, scratch2) = scratch1.tmp_vec_znx_dft(module, cols_in, lhs.size());
(0..cols_in).for_each(|col_i| {
module.vec_znx_dft(&mut ai_dft, col_i, lhs, col_i + 1);
});
module.vmp_apply(&mut res_dft, &ai_dft, rhs, scratch2);
}
let mut res_big: VecZnxBig<&mut [u8], FFT64> = module.vec_znx_idft_consume(res_dft);
module.vec_znx_big_add_small_inplace(&mut res_big, 0, lhs, 0);
(0..cols_out).for_each(|i| {
module.vec_znx_big_normalize(basek, self, i, &res_big, i, scratch1);
});
}
pub fn keyswitch_inplace<DataRhs>(
@@ -246,7 +304,10 @@ where
) where
MatZnxDft<DataRhs, FFT64>: MatZnxDftToRef<FFT64>,
{
rhs.0.prod_with_glwe_inplace(module, self, scratch);
unsafe {
let self_ptr: *mut GLWECiphertext<DataSelf> = self as *mut GLWECiphertext<DataSelf>;
self.keyswitch(&module, &*self_ptr, rhs, scratch);
}
}
pub fn external_product<DataLhs, DataRhs>(
@@ -259,7 +320,36 @@ where
VecZnx<DataLhs>: VecZnxToRef,
MatZnxDft<DataRhs, FFT64>: MatZnxDftToRef<FFT64>,
{
rhs.prod_with_glwe(module, self, lhs, scratch);
let basek: usize = self.basek();
#[cfg(debug_assertions)]
{
assert_eq!(rhs.rank(), lhs.rank());
assert_eq!(rhs.rank(), self.rank());
assert_eq!(self.basek(), basek);
assert_eq!(lhs.basek(), basek);
assert_eq!(rhs.n(), module.n());
assert_eq!(self.n(), module.n());
assert_eq!(lhs.n(), module.n());
}
let cols: usize = rhs.rank() + 1;
let (mut res_dft, scratch1) = scratch.tmp_vec_znx_dft(module, cols, rhs.size()); // Todo optimise
{
let (mut a_dft, scratch2) = scratch1.tmp_vec_znx_dft(module, cols, lhs.size());
(0..cols).for_each(|col_i| {
module.vec_znx_dft(&mut a_dft, col_i, lhs, col_i);
});
module.vmp_apply(&mut res_dft, &a_dft, rhs, scratch2);
}
let res_big: VecZnxBig<&mut [u8], FFT64> = module.vec_znx_idft_consume(res_dft);
(0..cols).for_each(|i| {
module.vec_znx_big_normalize(basek, self, i, &res_big, i, scratch1);
});
}
pub fn external_product_inplace<DataRhs>(
@@ -270,7 +360,10 @@ where
) where
MatZnxDft<DataRhs, FFT64>: MatZnxDftToRef<FFT64>,
{
rhs.prod_with_glwe_inplace(module, self, scratch);
unsafe {
let self_ptr: *mut GLWECiphertext<DataSelf> = self as *mut GLWECiphertext<DataSelf>;
self.external_product(&module, &*self_ptr, rhs, scratch);
}
}
pub(crate) fn encrypt_sk_private<DataPt, DataSk>(