Add cross-basek normalization (#90)

* added cross_basek_normalization

* updated method signatures to take layouts

* fixed cross-base normalization

fix #91
fix #93
This commit is contained in:
Jean-Philippe Bossuat
2025-09-30 14:40:10 +02:00
committed by GitHub
parent 4da790ea6a
commit 37e13b965c
216 changed files with 12481 additions and 7745 deletions

View File

@@ -3,7 +3,7 @@ use poulpy_hal::{
ScratchAvailable, SvpApplyDftToDft, SvpApplyDftToDftInplace, SvpPPolAllocBytes, SvpPrepare, TakeScalarZnx, TakeSvpPPol,
TakeVecZnx, TakeVecZnxDft, VecZnxAddInplace, VecZnxAddNormal, VecZnxBigAddNormal, VecZnxBigAddSmallInplace,
VecZnxBigAllocBytes, VecZnxBigNormalize, VecZnxDftAllocBytes, VecZnxDftApply, VecZnxFillUniform, VecZnxIdftApplyConsume,
VecZnxNormalize, VecZnxNormalizeInplace, VecZnxNormalizeTmpBytes, VecZnxSub, VecZnxSubABInplace,
VecZnxNormalize, VecZnxNormalizeInplace, VecZnxNormalizeTmpBytes, VecZnxSub, VecZnxSubInplace,
},
layouts::{Backend, DataMut, DataRef, Module, ScalarZnx, Scratch, VecZnx, VecZnxBig, ZnxInfos, ZnxZero},
source::Source,
@@ -13,26 +13,30 @@ use crate::{
dist::Distribution,
encryption::{SIGMA, SIGMA_BOUND},
layouts::{
GLWECiphertext, GLWEPlaintext, Infos,
GLWECiphertext, GLWEInfos, GLWEPlaintext, LWEInfos,
prepared::{GLWEPublicKeyPrepared, GLWESecretPrepared},
},
};
impl GLWECiphertext<Vec<u8>> {
pub fn encrypt_sk_scratch_space<B: Backend>(module: &Module<B>, basek: usize, k: usize) -> usize
pub fn encrypt_sk_scratch_space<B: Backend, A>(module: &Module<B>, infos: &A) -> usize
where
A: GLWEInfos,
Module<B>: VecZnxNormalizeTmpBytes + VecZnxDftAllocBytes,
{
let size: usize = k.div_ceil(basek);
let size: usize = infos.size();
assert_eq!(module.n() as u32, infos.n());
module.vec_znx_normalize_tmp_bytes()
+ 2 * VecZnx::alloc_bytes(module.n(), 1, size)
+ module.vec_znx_dft_alloc_bytes(1, size)
}
pub fn encrypt_pk_scratch_space<B: Backend>(module: &Module<B>, basek: usize, k: usize) -> usize
pub fn encrypt_pk_scratch_space<B: Backend, A>(module: &Module<B>, infos: &A) -> usize
where
A: GLWEInfos,
Module<B>: VecZnxDftAllocBytes + SvpPPolAllocBytes + VecZnxBigAllocBytes + VecZnxNormalizeTmpBytes,
{
let size: usize = k.div_ceil(basek);
let size: usize = infos.size();
assert_eq!(module.n() as u32, infos.n());
((module.vec_znx_dft_alloc_bytes(1, size) + module.vec_znx_big_alloc_bytes(1, size))
| ScalarZnx::alloc_bytes(module.n(), 1))
+ module.svp_ppol_alloc_bytes(1)
@@ -58,7 +62,7 @@ impl<DataSelf: DataMut> GLWECiphertext<DataSelf> {
+ VecZnxIdftApplyConsume<B>
+ VecZnxNormalizeTmpBytes
+ VecZnxFillUniform
+ VecZnxSubABInplace
+ VecZnxSubInplace
+ VecZnxAddInplace
+ VecZnxNormalizeInplace<B>
+ VecZnxAddNormal
@@ -72,10 +76,10 @@ impl<DataSelf: DataMut> GLWECiphertext<DataSelf> {
assert_eq!(sk.n(), self.n());
assert_eq!(pt.n(), self.n());
assert!(
scratch.available() >= GLWECiphertext::encrypt_sk_scratch_space(module, self.basek(), self.k()),
scratch.available() >= GLWECiphertext::encrypt_sk_scratch_space(module, self),
"scratch.available(): {} < GLWECiphertext::encrypt_sk_scratch_space: {}",
scratch.available(),
GLWECiphertext::encrypt_sk_scratch_space(module, self.basek(), self.k())
GLWECiphertext::encrypt_sk_scratch_space(module, self)
)
}
@@ -97,7 +101,7 @@ impl<DataSelf: DataMut> GLWECiphertext<DataSelf> {
+ VecZnxIdftApplyConsume<B>
+ VecZnxNormalizeTmpBytes
+ VecZnxFillUniform
+ VecZnxSubABInplace
+ VecZnxSubInplace
+ VecZnxAddInplace
+ VecZnxNormalizeInplace<B>
+ VecZnxAddNormal
@@ -110,10 +114,10 @@ impl<DataSelf: DataMut> GLWECiphertext<DataSelf> {
assert_eq!(self.rank(), sk.rank());
assert_eq!(sk.n(), self.n());
assert!(
scratch.available() >= GLWECiphertext::encrypt_sk_scratch_space(module, self.basek(), self.k()),
scratch.available() >= GLWECiphertext::encrypt_sk_scratch_space(module, self),
"scratch.available(): {} < GLWECiphertext::encrypt_sk_scratch_space: {}",
scratch.available(),
GLWECiphertext::encrypt_sk_scratch_space(module, self.basek(), self.k())
GLWECiphertext::encrypt_sk_scratch_space(module, self)
)
}
self.encrypt_sk_internal(
@@ -143,7 +147,7 @@ impl<DataSelf: DataMut> GLWECiphertext<DataSelf> {
+ VecZnxIdftApplyConsume<B>
+ VecZnxNormalizeTmpBytes
+ VecZnxFillUniform
+ VecZnxSubABInplace
+ VecZnxSubInplace
+ VecZnxAddInplace
+ VecZnxNormalizeInplace<B>
+ VecZnxAddNormal
@@ -151,11 +155,11 @@ impl<DataSelf: DataMut> GLWECiphertext<DataSelf> {
+ VecZnxSub,
Scratch<B>: TakeVecZnxDft<B> + ScratchAvailable + TakeVecZnx,
{
let cols: usize = self.rank() + 1;
let cols: usize = (self.rank() + 1).into();
glwe_encrypt_sk_internal(
module,
self.basek(),
self.k(),
self.base2k().into(),
self.k().into(),
&mut self.data,
cols,
false,
@@ -235,24 +239,24 @@ impl<DataSelf: DataMut> GLWECiphertext<DataSelf> {
{
#[cfg(debug_assertions)]
{
assert_eq!(self.basek(), pk.basek());
assert_eq!(self.base2k(), pk.base2k());
assert_eq!(self.n(), pk.n());
assert_eq!(self.rank(), pk.rank());
if let Some((pt, _)) = pt {
assert_eq!(pt.basek(), pk.basek());
assert_eq!(pt.base2k(), pk.base2k());
assert_eq!(pt.n(), pk.n());
}
}
let basek: usize = pk.basek();
let base2k: usize = pk.base2k().into();
let size_pk: usize = pk.size();
let cols: usize = self.rank() + 1;
let cols: usize = (self.rank() + 1).into();
// Generates u according to the underlying secret distribution.
let (mut u_dft, scratch_1) = scratch.take_svp_ppol(self.n(), 1);
let (mut u_dft, scratch_1) = scratch.take_svp_ppol(self.n().into(), 1);
{
let (mut u, _) = scratch_1.take_scalar_znx(self.n(), 1);
let (mut u, _) = scratch_1.take_scalar_znx(self.n().into(), 1);
match pk.dist {
Distribution::NONE => panic!(
"invalid public key: SecretDistribution::NONE, ensure it has been correctly intialized through \
@@ -271,7 +275,7 @@ impl<DataSelf: DataMut> GLWECiphertext<DataSelf> {
// ct[i] = pk[i] * u + ei (+ m if col = i)
(0..cols).for_each(|i| {
let (mut ci_dft, scratch_2) = scratch_1.take_vec_znx_dft(self.n(), 1, size_pk);
let (mut ci_dft, scratch_2) = scratch_1.take_vec_znx_dft(self.n().into(), 1, size_pk);
// ci_dft = DFT(u) * DFT(pk[i])
module.svp_apply_dft_to_dft(&mut ci_dft, 0, &u_dft, 0, &pk.data, i);
@@ -279,7 +283,15 @@ impl<DataSelf: DataMut> GLWECiphertext<DataSelf> {
let mut ci_big = module.vec_znx_idft_apply_consume(ci_dft);
// ci_big = u * pk[i] + e
module.vec_znx_big_add_normal(basek, &mut ci_big, 0, pk.k(), source_xe, SIGMA, SIGMA_BOUND);
module.vec_znx_big_add_normal(
base2k,
&mut ci_big,
0,
pk.k().into(),
source_xe,
SIGMA,
SIGMA_BOUND,
);
// ci_big = u * pk[i] + e + m (if col = i)
if let Some((pt, col)) = pt
@@ -289,7 +301,7 @@ impl<DataSelf: DataMut> GLWECiphertext<DataSelf> {
}
// ct[i] = norm(ci_big)
module.vec_znx_big_normalize(basek, &mut self.data, i, &ci_big, 0, scratch_2);
module.vec_znx_big_normalize(base2k, &mut self.data, i, base2k, &ci_big, 0, scratch_2);
});
}
}
@@ -297,7 +309,7 @@ impl<DataSelf: DataMut> GLWECiphertext<DataSelf> {
#[allow(clippy::too_many_arguments)]
pub(crate) fn glwe_encrypt_sk_internal<DataCt: DataMut, DataPt: DataRef, DataSk: DataRef, B: Backend>(
module: &Module<B>,
basek: usize,
base2k: usize,
k: usize,
ct: &mut VecZnx<DataCt>,
cols: usize,
@@ -316,7 +328,7 @@ pub(crate) fn glwe_encrypt_sk_internal<DataCt: DataMut, DataPt: DataRef, DataSk:
+ VecZnxIdftApplyConsume<B>
+ VecZnxNormalizeTmpBytes
+ VecZnxFillUniform
+ VecZnxSubABInplace
+ VecZnxSubInplace
+ VecZnxAddInplace
+ VecZnxNormalizeInplace<B>
+ VecZnxAddNormal
@@ -350,7 +362,7 @@ pub(crate) fn glwe_encrypt_sk_internal<DataCt: DataMut, DataPt: DataRef, DataSk:
let col_ct: usize = if compressed { 0 } else { i };
// ct[i] = uniform (+ pt)
module.vec_znx_fill_uniform(basek, ct, col_ct, source_xa);
module.vec_znx_fill_uniform(base2k, ct, col_ct, source_xa);
let (mut ci_dft, scratch_3) = scratch_2.take_vec_znx_dft(ct.n(), 1, size);
@@ -360,7 +372,7 @@ pub(crate) fn glwe_encrypt_sk_internal<DataCt: DataMut, DataPt: DataRef, DataSk:
if let Some((pt, col)) = pt {
if i == col {
module.vec_znx_sub(&mut ci, 0, ct, col_ct, &pt.data, 0);
module.vec_znx_normalize_inplace(basek, &mut ci, 0, scratch_3);
module.vec_znx_normalize_inplace(base2k, &mut ci, 0, scratch_3);
module.vec_znx_dft_apply(1, 0, &mut ci_dft, 0, &ci, 0);
} else {
module.vec_znx_dft_apply(1, 0, &mut ci_dft, 0, ct, col_ct);
@@ -373,15 +385,15 @@ pub(crate) fn glwe_encrypt_sk_internal<DataCt: DataMut, DataPt: DataRef, DataSk:
let ci_big: VecZnxBig<&mut [u8], B> = module.vec_znx_idft_apply_consume(ci_dft);
// use c[0] as buffer, which is overwritten later by the normalization step
module.vec_znx_big_normalize(basek, &mut ci, 0, &ci_big, 0, scratch_3);
module.vec_znx_big_normalize(base2k, &mut ci, 0, base2k, &ci_big, 0, scratch_3);
// c0_tmp = -c[i] * s[i] (use c[0] as buffer)
module.vec_znx_sub_ab_inplace(&mut c0, 0, &ci, 0);
module.vec_znx_sub_inplace(&mut c0, 0, &ci, 0);
});
}
// c[0] += e
module.vec_znx_add_normal(basek, &mut c0, 0, k, source_xe, sigma, SIGMA_BOUND);
module.vec_znx_add_normal(base2k, &mut c0, 0, k, source_xe, sigma, SIGMA_BOUND);
// c[0] += m if col = 0
if let Some((pt, col)) = pt
@@ -391,5 +403,5 @@ pub(crate) fn glwe_encrypt_sk_internal<DataCt: DataMut, DataPt: DataRef, DataSk:
}
// c[0] = norm(c[0])
module.vec_znx_normalize(basek, ct, 0, &c0, 0, scratch_1);
module.vec_znx_normalize(base2k, ct, 0, base2k, &c0, 0, scratch_1);
}