Add cross-basek normalization (#90)

* added cross_basek_normalization

* updated method signatures to take layouts

* fixed cross-base normalization

fix #91
fix #93
This commit is contained in:
Jean-Philippe Bossuat
2025-09-30 14:40:10 +02:00
committed by GitHub
parent 4da790ea6a
commit 37e13b965c
216 changed files with 12481 additions and 7745 deletions

View File

@@ -3,7 +3,7 @@ use poulpy_hal::{
ScratchAvailable, SvpApplyDftToDft, SvpApplyDftToDftInplace, SvpPPolAllocBytes, SvpPrepare, TakeScalarZnx, TakeVecZnx,
TakeVecZnxBig, TakeVecZnxDft, VecZnxAddInplace, VecZnxAddNormal, VecZnxAddScalarInplace, VecZnxBigAllocBytes,
VecZnxBigNormalize, VecZnxDftAllocBytes, VecZnxDftApply, VecZnxFillUniform, VecZnxIdftApplyConsume, VecZnxIdftApplyTmpA,
VecZnxNormalize, VecZnxNormalizeInplace, VecZnxNormalizeTmpBytes, VecZnxSub, VecZnxSubABInplace, VecZnxSwitchRing,
VecZnxNormalize, VecZnxNormalizeInplace, VecZnxNormalizeTmpBytes, VecZnxSub, VecZnxSubInplace, VecZnxSwitchRing,
},
layouts::{Backend, DataMut, DataRef, Module, Scratch},
source::Source,
@@ -12,23 +12,24 @@ use poulpy_hal::{
use crate::{
TakeGLWESecret, TakeGLWESecretPrepared,
layouts::{
GGLWESwitchingKey, GGLWETensorKey, GLWESecret, Infos,
Degree, GGLWELayoutInfos, GGLWESwitchingKey, GGLWETensorKey, GLWEInfos, GLWESecret, LWEInfos, Rank,
prepared::{GLWESecretPrepared, Prepare},
},
};
impl GGLWETensorKey<Vec<u8>> {
pub fn encrypt_sk_scratch_space<B: Backend>(module: &Module<B>, basek: usize, k: usize, rank: usize) -> usize
pub fn encrypt_sk_scratch_space<B: Backend, A>(module: &Module<B>, infos: &A) -> usize
where
A: GGLWELayoutInfos,
Module<B>:
SvpPPolAllocBytes + VecZnxNormalizeTmpBytes + VecZnxDftAllocBytes + VecZnxNormalizeTmpBytes + VecZnxBigAllocBytes,
{
GLWESecretPrepared::bytes_of(module, rank)
+ module.vec_znx_dft_alloc_bytes(rank, 1)
GLWESecretPrepared::alloc_bytes_with(module, infos.rank_out())
+ module.vec_znx_dft_alloc_bytes(infos.rank_out().into(), 1)
+ module.vec_znx_big_alloc_bytes(1, 1)
+ module.vec_znx_dft_alloc_bytes(1, 1)
+ GLWESecret::bytes_of(module.n(), 1)
+ GGLWESwitchingKey::encrypt_sk_scratch_space(module, basek, k, rank, rank)
+ GLWESecret::alloc_bytes_with(Degree(module.n() as u32), Rank(1))
+ GGLWESwitchingKey::encrypt_sk_scratch_space(module, infos)
}
}
@@ -51,7 +52,7 @@ impl<DataSelf: DataMut> GGLWETensorKey<DataSelf> {
+ VecZnxIdftApplyConsume<B>
+ VecZnxNormalizeTmpBytes
+ VecZnxFillUniform
+ VecZnxSubABInplace
+ VecZnxSubInplace
+ VecZnxAddInplace
+ VecZnxNormalizeInplace<B>
+ VecZnxAddNormal
@@ -65,36 +66,36 @@ impl<DataSelf: DataMut> GGLWETensorKey<DataSelf> {
{
#[cfg(debug_assertions)]
{
assert_eq!(self.rank(), sk.rank());
assert_eq!(self.rank_out(), sk.rank());
assert_eq!(self.n(), sk.n());
}
let n: usize = sk.n();
let rank: usize = self.rank();
let n: Degree = sk.n();
let rank: Rank = self.rank_out();
let (mut sk_dft_prep, scratch_1) = scratch.take_glwe_secret_prepared(n, rank);
sk_dft_prep.prepare(module, sk, scratch_1);
let (mut sk_dft, scratch_2) = scratch_1.take_vec_znx_dft(n, rank, 1);
let (mut sk_dft, scratch_2) = scratch_1.take_vec_znx_dft(n.into(), rank.into(), 1);
(0..rank).for_each(|i| {
(0..rank.into()).for_each(|i| {
module.vec_znx_dft_apply(1, 0, &mut sk_dft, i, &sk.data.as_vec_znx(), i);
});
let (mut sk_ij_big, scratch_3) = scratch_2.take_vec_znx_big(n, 1, 1);
let (mut sk_ij, scratch_4) = scratch_3.take_glwe_secret(n, 1);
let (mut sk_ij_dft, scratch_5) = scratch_4.take_vec_znx_dft(n, 1, 1);
let (mut sk_ij_big, scratch_3) = scratch_2.take_vec_znx_big(n.into(), 1, 1);
let (mut sk_ij, scratch_4) = scratch_3.take_glwe_secret(n, Rank(1));
let (mut sk_ij_dft, scratch_5) = scratch_4.take_vec_znx_dft(n.into(), 1, 1);
(0..rank).for_each(|i| {
(i..rank).for_each(|j| {
(0..rank.into()).for_each(|i| {
(i..rank.into()).for_each(|j| {
module.svp_apply_dft_to_dft(&mut sk_ij_dft, 0, &sk_dft_prep.data, j, &sk_dft, i);
module.vec_znx_idft_apply_tmpa(&mut sk_ij_big, 0, &mut sk_ij_dft, 0);
module.vec_znx_big_normalize(
self.basek(),
self.base2k().into(),
&mut sk_ij.data.as_vec_znx_mut(),
0,
self.base2k().into(),
&sk_ij_big,
0,
scratch_5,