Add cross-basek normalization (#90)

* added cross_basek_normalization

* updated method signatures to take layouts

* fixed cross-base normalization

fix #91
fix #93
This commit is contained in:
Jean-Philippe Bossuat
2025-09-30 14:40:10 +02:00
committed by GitHub
parent 4da790ea6a
commit 37e13b965c
216 changed files with 12481 additions and 7745 deletions

View File

@@ -8,11 +8,11 @@ use poulpy_hal::{
oep::{ScratchOwnedAllocImpl, ScratchOwnedBorrowImpl, TakeVecZnxBigImpl, TakeVecZnxDftImpl},
};
use crate::layouts::{GGLWECiphertext, GLWECiphertext, GLWEPlaintext, Infos, prepared::GLWESecretPrepared};
use crate::layouts::{GGLWECiphertext, GGLWELayoutInfos, GLWECiphertext, GLWEPlaintext, LWEInfos, prepared::GLWESecretPrepared};
impl<D: DataRef> GGLWECiphertext<D> {
pub fn assert_noise<B, DataSk, DataWant>(
self,
&self,
module: &Module<B>,
sk: &GLWESecretPrepared<DataSk, B>,
pt_want: &ScalarZnx<DataWant>,
@@ -32,15 +32,14 @@ impl<D: DataRef> GGLWECiphertext<D> {
+ VecZnxSubScalarInplace,
B: Backend + TakeVecZnxDftImpl<B> + TakeVecZnxBigImpl<B> + ScratchOwnedAllocImpl<B> + ScratchOwnedBorrowImpl<B>,
{
let digits: usize = self.digits();
let basek: usize = self.basek();
let k: usize = self.k();
let digits: usize = self.digits().into();
let base2k: usize = self.base2k().into();
let mut scratch: ScratchOwned<B> = ScratchOwned::alloc(GLWECiphertext::decrypt_scratch_space(module, basek, k));
let mut pt: GLWEPlaintext<Vec<u8>> = GLWEPlaintext::alloc(self.n(), basek, k);
let mut scratch: ScratchOwned<B> = ScratchOwned::alloc(GLWECiphertext::decrypt_scratch_space(module, self));
let mut pt: GLWEPlaintext<Vec<u8>> = GLWEPlaintext::alloc(self);
(0..self.rank_in()).for_each(|col_i| {
(0..self.rows()).for_each(|row_i| {
(0..self.rank_in().into()).for_each(|col_i| {
(0..self.rows().into()).for_each(|row_i| {
self.at(row_i, col_i)
.decrypt(module, &mut pt, sk, scratch.borrow());
@@ -52,13 +51,13 @@ impl<D: DataRef> GGLWECiphertext<D> {
col_i,
);
let noise_have: f64 = pt.data.std(basek, 0).log2();
let noise_have: f64 = pt.data.std(base2k, 0).log2();
println!("noise_have: {noise_have}");
assert!(
noise_have <= max_noise,
"noise_have: {} > max_noise: {}",
noise_have,
max_noise
"noise_have: {noise_have} > max_noise: {max_noise}"
);
pt.data.zero();

View File

@@ -3,13 +3,15 @@ use poulpy_hal::{
ScratchOwnedAlloc, ScratchOwnedBorrow, SvpApplyDftToDftInplace, VecZnxAddScalarInplace, VecZnxBigAddInplace,
VecZnxBigAddSmallInplace, VecZnxBigAlloc, VecZnxBigAllocBytes, VecZnxBigNormalize, VecZnxBigNormalizeTmpBytes,
VecZnxDftAlloc, VecZnxDftAllocBytes, VecZnxDftApply, VecZnxIdftApplyConsume, VecZnxIdftApplyTmpA,
VecZnxNormalizeTmpBytes, VecZnxSubABInplace,
VecZnxNormalizeTmpBytes, VecZnxSubInplace,
},
layouts::{Backend, DataRef, Module, ScalarZnx, ScratchOwned, VecZnxBig, VecZnxDft, ZnxZero},
oep::{ScratchOwnedAllocImpl, ScratchOwnedBorrowImpl, TakeVecZnxBigImpl, TakeVecZnxDftImpl},
};
use crate::layouts::{GGSWCiphertext, GLWECiphertext, GLWEPlaintext, Infos, prepared::GLWESecretPrepared};
use crate::layouts::{
GGSWCiphertext, GGSWInfos, GLWECiphertext, GLWEInfos, GLWEPlaintext, LWEInfos, prepared::GLWESecretPrepared,
};
impl<D: DataRef> GGSWCiphertext<D> {
pub fn assert_noise<B, DataSk, DataScalar, F>(
@@ -35,24 +37,23 @@ impl<D: DataRef> GGSWCiphertext<D> {
+ VecZnxBigNormalizeTmpBytes
+ VecZnxIdftApplyTmpA<B>
+ VecZnxAddScalarInplace
+ VecZnxSubABInplace,
+ VecZnxSubInplace,
B: Backend + TakeVecZnxDftImpl<B> + TakeVecZnxBigImpl<B> + ScratchOwnedAllocImpl<B> + ScratchOwnedBorrowImpl<B>,
F: Fn(usize) -> f64,
{
let basek: usize = self.basek();
let k: usize = self.k();
let digits: usize = self.digits();
let base2k: usize = self.base2k().into();
let digits: usize = self.digits().into();
let mut pt: GLWEPlaintext<Vec<u8>> = GLWEPlaintext::alloc(self.n(), basek, k);
let mut pt_have: GLWEPlaintext<Vec<u8>> = GLWEPlaintext::alloc(self.n(), basek, k);
let mut pt: GLWEPlaintext<Vec<u8>> = GLWEPlaintext::alloc(self);
let mut pt_have: GLWEPlaintext<Vec<u8>> = GLWEPlaintext::alloc(self);
let mut pt_dft: VecZnxDft<Vec<u8>, B> = module.vec_znx_dft_alloc(1, self.size());
let mut pt_big: VecZnxBig<Vec<u8>, B> = module.vec_znx_big_alloc(1, self.size());
let mut scratch: ScratchOwned<B> =
ScratchOwned::alloc(GLWECiphertext::decrypt_scratch_space(module, basek, k) | module.vec_znx_normalize_tmp_bytes());
ScratchOwned::alloc(GLWECiphertext::decrypt_scratch_space(module, self) | module.vec_znx_normalize_tmp_bytes());
(0..self.rank() + 1).for_each(|col_j| {
(0..self.rows()).for_each(|row_i| {
(0..(self.rank() + 1).into()).for_each(|col_j| {
(0..self.rows().into()).for_each(|row_i| {
module.vec_znx_add_scalar_inplace(&mut pt.data, 0, (digits - 1) + row_i * digits, pt_want, 0);
// mul with sk[col_j-1]
@@ -60,17 +61,25 @@ impl<D: DataRef> GGSWCiphertext<D> {
module.vec_znx_dft_apply(1, 0, &mut pt_dft, 0, &pt.data, 0);
module.svp_apply_dft_to_dft_inplace(&mut pt_dft, 0, &sk_prepared.data, col_j - 1);
module.vec_znx_idft_apply_tmpa(&mut pt_big, 0, &mut pt_dft, 0);
module.vec_znx_big_normalize(basek, &mut pt.data, 0, &pt_big, 0, scratch.borrow());
module.vec_znx_big_normalize(
base2k,
&mut pt.data,
0,
base2k,
&pt_big,
0,
scratch.borrow(),
);
}
self.at(row_i, col_j)
.decrypt(module, &mut pt_have, sk_prepared, scratch.borrow());
module.vec_znx_sub_ab_inplace(&mut pt_have.data, 0, &pt.data, 0);
module.vec_znx_sub_inplace(&mut pt_have.data, 0, &pt.data, 0);
let std_pt: f64 = pt_have.data.std(basek, 0).log2();
let std_pt: f64 = pt_have.data.std(base2k, 0).log2();
let noise: f64 = max_noise(col_j);
assert!(std_pt <= noise, "{} > {}", std_pt, noise);
assert!(std_pt <= noise, "{std_pt} > {noise}");
pt.data.zero();
});
@@ -101,23 +110,22 @@ impl<D: DataRef> GGSWCiphertext<D> {
+ VecZnxBigNormalizeTmpBytes
+ VecZnxIdftApplyTmpA<B>
+ VecZnxAddScalarInplace
+ VecZnxSubABInplace,
+ VecZnxSubInplace,
B: Backend + TakeVecZnxDftImpl<B> + TakeVecZnxBigImpl<B> + ScratchOwnedAllocImpl<B> + ScratchOwnedBorrowImpl<B>,
{
let basek: usize = self.basek();
let k: usize = self.k();
let digits: usize = self.digits();
let base2k: usize = self.base2k().into();
let digits: usize = self.digits().into();
let mut pt: GLWEPlaintext<Vec<u8>> = GLWEPlaintext::alloc(self.n(), basek, k);
let mut pt_have: GLWEPlaintext<Vec<u8>> = GLWEPlaintext::alloc(self.n(), basek, k);
let mut pt: GLWEPlaintext<Vec<u8>> = GLWEPlaintext::alloc(self);
let mut pt_have: GLWEPlaintext<Vec<u8>> = GLWEPlaintext::alloc(self);
let mut pt_dft: VecZnxDft<Vec<u8>, B> = module.vec_znx_dft_alloc(1, self.size());
let mut pt_big: VecZnxBig<Vec<u8>, B> = module.vec_znx_big_alloc(1, self.size());
let mut scratch: ScratchOwned<B> =
ScratchOwned::alloc(GLWECiphertext::decrypt_scratch_space(module, basek, k) | module.vec_znx_normalize_tmp_bytes());
ScratchOwned::alloc(GLWECiphertext::decrypt_scratch_space(module, self) | module.vec_znx_normalize_tmp_bytes());
(0..self.rank() + 1).for_each(|col_j| {
(0..self.rows()).for_each(|row_i| {
(0..(self.rank() + 1).into()).for_each(|col_j| {
(0..self.rows().into()).for_each(|row_i| {
module.vec_znx_add_scalar_inplace(&mut pt.data, 0, (digits - 1) + row_i * digits, pt_want, 0);
// mul with sk[col_j-1]
@@ -125,16 +133,24 @@ impl<D: DataRef> GGSWCiphertext<D> {
module.vec_znx_dft_apply(1, 0, &mut pt_dft, 0, &pt.data, 0);
module.svp_apply_dft_to_dft_inplace(&mut pt_dft, 0, &sk_prepared.data, col_j - 1);
module.vec_znx_idft_apply_tmpa(&mut pt_big, 0, &mut pt_dft, 0);
module.vec_znx_big_normalize(basek, &mut pt.data, 0, &pt_big, 0, scratch.borrow());
module.vec_znx_big_normalize(
base2k,
&mut pt.data,
0,
base2k,
&pt_big,
0,
scratch.borrow(),
);
}
self.at(row_i, col_j)
.decrypt(module, &mut pt_have, sk_prepared, scratch.borrow());
module.vec_znx_sub_ab_inplace(&mut pt_have.data, 0, &pt.data, 0);
module.vec_znx_sub_inplace(&mut pt_have.data, 0, &pt.data, 0);
let std_pt: f64 = pt_have.data.std(basek, 0).log2();
println!("col: {} row: {}: {}", col_j, row_i, std_pt);
let std_pt: f64 = pt_have.data.std(base2k, 0).log2();
println!("col: {col_j} row: {row_i}: {std_pt}");
pt.data.zero();
});
});

View File

@@ -2,17 +2,13 @@ use poulpy_hal::{
api::{
ScratchOwnedAlloc, ScratchOwnedBorrow, SvpApplyDftToDftInplace, VecZnxBigAddInplace, VecZnxBigAddSmallInplace,
VecZnxBigAllocBytes, VecZnxBigNormalize, VecZnxDftAllocBytes, VecZnxDftApply, VecZnxIdftApplyConsume,
VecZnxNormalizeInplace, VecZnxNormalizeTmpBytes, VecZnxSubABInplace,
VecZnxNormalizeInplace, VecZnxNormalizeTmpBytes, VecZnxSubInplace,
},
layouts::{Backend, DataRef, Module, ScratchOwned},
oep::{ScratchOwnedAllocImpl, ScratchOwnedBorrowImpl, TakeVecZnxBigImpl, TakeVecZnxDftImpl},
};
use crate::{
layouts::GLWEPlaintext,
layouts::prepared::GLWESecretPrepared,
layouts::{GLWECiphertext, Infos},
};
use crate::layouts::{GLWECiphertext, GLWEPlaintext, LWEInfos, prepared::GLWESecretPrepared};
impl<D: DataRef> GLWECiphertext<D> {
pub fn assert_noise<B, DataSk, DataPt>(
@@ -33,24 +29,20 @@ impl<D: DataRef> GLWECiphertext<D> {
+ VecZnxBigAddSmallInplace<B>
+ VecZnxBigNormalize<B>
+ VecZnxNormalizeTmpBytes
+ VecZnxSubABInplace
+ VecZnxSubInplace
+ VecZnxNormalizeInplace<B>,
B: Backend + TakeVecZnxDftImpl<B> + TakeVecZnxBigImpl<B> + ScratchOwnedAllocImpl<B> + ScratchOwnedBorrowImpl<B>,
{
let mut pt_have: GLWEPlaintext<Vec<u8>> = GLWEPlaintext::alloc(self.n(), self.basek(), self.k());
let mut pt_have: GLWEPlaintext<Vec<u8>> = GLWEPlaintext::alloc(self);
let mut scratch: ScratchOwned<B> = ScratchOwned::alloc(GLWECiphertext::decrypt_scratch_space(
module,
self.basek(),
self.k(),
));
let mut scratch: ScratchOwned<B> = ScratchOwned::alloc(GLWECiphertext::decrypt_scratch_space(module, self));
self.decrypt(module, &mut pt_have, sk_prepared, scratch.borrow());
module.vec_znx_sub_ab_inplace(&mut pt_have.data, 0, &pt_want.data, 0);
module.vec_znx_normalize_inplace(self.basek(), &mut pt_have.data, 0, scratch.borrow());
module.vec_znx_sub_inplace(&mut pt_have.data, 0, &pt_want.data, 0);
module.vec_znx_normalize_inplace(self.base2k().into(), &mut pt_have.data, 0, scratch.borrow());
let noise_have: f64 = pt_have.data.std(self.basek(), 0).log2();
assert!(noise_have <= max_noise, "{} {}", noise_have, max_noise);
let noise_have: f64 = pt_have.data.std(self.base2k().into(), 0).log2();
assert!(noise_have <= max_noise, "{noise_have} {max_noise}");
}
}

View File

@@ -6,7 +6,7 @@ mod glwe_ct;
#[allow(dead_code)]
pub(crate) fn var_noise_gglwe_product(
n: f64,
basek: usize,
base2k: usize,
var_xs: f64,
var_msg: f64,
var_a_err: f64,
@@ -17,12 +17,12 @@ pub(crate) fn var_noise_gglwe_product(
b_logq: usize,
) -> f64 {
let a_logq: usize = a_logq.min(b_logq);
let a_cols: usize = a_logq.div_ceil(basek);
let a_cols: usize = a_logq.div_ceil(base2k);
let b_scale: f64 = (b_logq as f64).exp2();
let a_scale: f64 = ((b_logq - a_logq) as f64).exp2();
let base: f64 = (basek as f64).exp2();
let base: f64 = (base2k as f64).exp2();
let var_base: f64 = base * base / 12f64;
// lhs = a_cols * n * (var_base * var_gct_err_lhs + var_e_a * var_msg * p^2)
@@ -38,7 +38,7 @@ pub(crate) fn var_noise_gglwe_product(
#[allow(dead_code)]
pub(crate) fn log2_std_noise_gglwe_product(
n: f64,
basek: usize,
base2k: usize,
var_xs: f64,
var_msg: f64,
var_a_err: f64,
@@ -50,7 +50,7 @@ pub(crate) fn log2_std_noise_gglwe_product(
) -> f64 {
let mut noise: f64 = var_noise_gglwe_product(
n,
basek,
base2k,
var_xs,
var_msg,
var_a_err,
@@ -68,7 +68,7 @@ pub(crate) fn log2_std_noise_gglwe_product(
#[allow(dead_code)]
pub(crate) fn noise_ggsw_product(
n: f64,
basek: usize,
base2k: usize,
var_xs: f64,
var_msg: f64,
var_a0_err: f64,
@@ -80,12 +80,12 @@ pub(crate) fn noise_ggsw_product(
k_ggsw: usize,
) -> f64 {
let a_logq: usize = k_in.min(k_ggsw);
let a_cols: usize = a_logq.div_ceil(basek);
let a_cols: usize = a_logq.div_ceil(base2k);
let b_scale: f64 = (k_ggsw as f64).exp2();
let a_scale: f64 = ((k_ggsw - a_logq) as f64).exp2();
let base: f64 = (basek as f64).exp2();
let base: f64 = (base2k as f64).exp2();
let var_base: f64 = base * base / 12f64;
// lhs = a_cols * n * (var_base * var_gct_err_lhs + var_e_a * var_msg * p^2)
@@ -102,7 +102,7 @@ pub(crate) fn noise_ggsw_product(
#[allow(dead_code)]
pub(crate) fn noise_ggsw_keyswitch(
n: f64,
basek: usize,
base2k: usize,
col: usize,
var_xs: f64,
var_a_err: f64,
@@ -118,7 +118,7 @@ pub(crate) fn noise_ggsw_keyswitch(
// Initial KS for col = 0
let mut noise: f64 = var_noise_gglwe_product(
n,
basek,
base2k,
var_xs,
var_xs,
var_a_err,
@@ -133,7 +133,7 @@ pub(crate) fn noise_ggsw_keyswitch(
if col > 0 {
noise += var_noise_gglwe_product(
n,
basek,
base2k,
var_xs,
var_si_x_sj,
var_a_err + 1f64 / 12.0,