Add cross-basek normalization (#90)

* Added cross-basek normalization

* Updated method signatures to take layouts

* Fixed cross-base normalization

Fixes #91
Fixes #93
Jean-Philippe Bossuat authored on 2025-09-30 14:40:10 +02:00 (committed by GitHub)
parent 4da790ea6a, commit 37e13b965c
216 changed files with 12481 additions and 7745 deletions

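For readers skimming the diff below: cross-base2k normalization re-expresses limbs given in one base 2^k (the new `a_basek` parameter) as normalized limbs in another base 2^k (`res_basek`). The standalone sketch below illustrates the idea on a single scalar with balanced digits; the digit convention and the scalar setting are illustrative assumptions, not the crate's polynomial implementation, which additionally threads a `carry: &mut [i64]` scratch through the limbs.

```rust
// Standalone illustration of cross-base-2^k normalization on one scalar.
// Assumption: limbs are balanced digits in [-2^(k-1), 2^(k-1)), most
// significant first. Only the re-expression idea is shown here.

/// Recompose an integer from base-2^k limbs (most significant first).
fn recompose(limbs: &[i128], k: u32) -> i128 {
    limbs.iter().fold(0, |acc, &d| (acc << k) + d)
}

/// Decompose `x` into `len` balanced base-2^k limbs (most significant first).
fn normalize(mut x: i128, k: u32, len: usize) -> Vec<i128> {
    let base = 1i128 << k;
    let half = base >> 1;
    let mut out = vec![0i128; len];
    for i in (0..len).rev() {
        // Centered remainder, then propagate the carry into the next limb.
        let mut d = x.rem_euclid(base);
        if d >= half {
            d -= base;
        }
        out[i] = d;
        x = (x - d) >> k;
    }
    out
}

fn main() {
    // A value given as (possibly unnormalized) base-2^17 limbs...
    let a_basek = 17;
    let a_limbs = [45_000i128, -131_000, 99_999];
    let value = recompose(&a_limbs, a_basek);

    // ...re-expressed as balanced base-2^13 limbs: 3 * 17 = 51 bits of
    // precision need ceil(51 / 13) = 4 limbs.
    let res_basek = 13;
    let res_limbs = normalize(value, res_basek, 4);

    assert_eq!(recompose(&res_limbs, res_basek), value);
    println!("base 2^17 {:?} == base 2^13 {:?}", a_limbs, res_limbs);
}
```

The crate's `vec_znx_big_normalize` below takes exactly these two bases, `res_basek` for the output layout and `a_basek` for the input, which is the signature change driving most of this commit.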

@@ -9,12 +9,13 @@ use crate::{
reference::{
vec_znx::{
vec_znx_add, vec_znx_add_inplace, vec_znx_automorphism, vec_znx_automorphism_inplace, vec_znx_negate,
vec_znx_negate_inplace, vec_znx_normalize, vec_znx_sub, vec_znx_sub_ab_inplace, vec_znx_sub_ba_inplace,
vec_znx_negate_inplace, vec_znx_normalize, vec_znx_sub, vec_znx_sub_inplace, vec_znx_sub_negate_inplace,
},
znx::{
ZnxAdd, ZnxAddInplace, ZnxAutomorphism, ZnxCopy, ZnxNegate, ZnxNegateInplace, ZnxNormalizeFinalStep,
ZnxNormalizeFirstStep, ZnxNormalizeFirstStepCarryOnly, ZnxNormalizeMiddleStep, ZnxNormalizeMiddleStepCarryOnly,
ZnxSub, ZnxSubABInplace, ZnxSubBAInplace, ZnxZero, znx_add_normal_f64_ref,
ZnxAdd, ZnxAddInplace, ZnxAutomorphism, ZnxCopy, ZnxExtractDigitAddMul, ZnxMulPowerOfTwoInplace, ZnxNegate,
ZnxNegateInplace, ZnxNormalizeDigit, ZnxNormalizeFinalStep, ZnxNormalizeFirstStep, ZnxNormalizeFirstStepCarryOnly,
ZnxNormalizeMiddleStep, ZnxNormalizeMiddleStepCarryOnly, ZnxSub, ZnxSubInplace, ZnxSubNegateInplace, ZnxZero,
znx_add_normal_f64_ref,
},
},
source::Source,
@@ -230,20 +231,32 @@ where
}
pub fn vec_znx_big_normalize_tmp_bytes(n: usize) -> usize {
n * size_of::<i64>()
2 * n * size_of::<i64>()
}
pub fn vec_znx_big_normalize<R, A, BE>(basek: usize, res: &mut R, res_col: usize, a: &A, a_col: usize, carry: &mut [i64])
where
pub fn vec_znx_big_normalize<R, A, BE>(
res_basek: usize,
res: &mut R,
res_col: usize,
a_basek: usize,
a: &A,
a_col: usize,
carry: &mut [i64],
) where
R: VecZnxToMut,
A: VecZnxBigToRef<BE>,
BE: Backend<ScalarBig = i64>
+ ZnxZero
+ ZnxCopy
+ ZnxAddInplace
+ ZnxMulPowerOfTwoInplace
+ ZnxNormalizeFirstStepCarryOnly
+ ZnxNormalizeMiddleStepCarryOnly
+ ZnxNormalizeMiddleStep
+ ZnxNormalizeFinalStep
+ ZnxNormalizeFirstStep
+ ZnxZero,
+ ZnxExtractDigitAddMul
+ ZnxNormalizeDigit,
{
let a: VecZnxBig<&[u8], _> = a.to_ref();
let a_vznx: VecZnx<&[u8]> = VecZnx {
@@ -254,11 +267,11 @@ where
max_size: a.max_size,
};
vec_znx_normalize::<_, _, BE>(basek, res, res_col, &a_vznx, a_col, carry);
vec_znx_normalize::<_, _, BE>(res_basek, res, res_col, a_basek, &a_vznx, a_col, carry);
}
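Two things change for callers of `vec_znx_big_normalize` in the hunks above: the scratch reported by `vec_znx_big_normalize_tmp_bytes` grows from `n` to `2 * n` i64 words, and the output and input bases are passed separately (`res_basek` first, `a_basek` right before `a`). A minimal sizing sketch, with the helper copied verbatim from the diff and an example ring degree:

```rust
use std::mem::size_of;

// Copied from the hunk above: the normalization scratch now spans 2 * n i64 words.
pub fn vec_znx_big_normalize_tmp_bytes(n: usize) -> usize {
    2 * n * size_of::<i64>()
}

fn main() {
    let n = 1usize << 10; // example ring degree
    let bytes = vec_znx_big_normalize_tmp_bytes(n);

    // Callers that sized the `carry: &mut [i64]` scratch as n words must grow it.
    let carry = vec![0i64; bytes / size_of::<i64>()];
    assert_eq!(carry.len(), 2 * n);
    println!("normalize scratch: {bytes} bytes = {} i64 words", carry.len());
}
```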
pub fn vec_znx_big_add_normal_ref<R, B: Backend<ScalarBig = i64>>(
basek: usize,
base2k: usize,
res: &mut R,
res_col: usize,
k: usize,
@@ -275,8 +288,8 @@ pub fn vec_znx_big_add_normal_ref<R, B: Backend<ScalarBig = i64>>(
(bound.log2().ceil() as i64)
);
let limb: usize = k.div_ceil(basek) - 1;
let scale: f64 = (1 << ((limb + 1) * basek - k)) as f64;
let limb: usize = k.div_ceil(base2k) - 1;
let scale: f64 = (1 << ((limb + 1) * base2k - k)) as f64;
znx_add_normal_f64_ref(
res.at_mut(res_col, limb),
sigma * scale,
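The renamed `base2k` drives the limb/scale arithmetic above: noise for precision `k` lands in limb `k.div_ceil(base2k) - 1`, scaled by the power of two separating `k` from that limb's boundary. A standalone check of those two formulas with illustrative values:

```rust
fn main() {
    // Same formulas as in the hunk above.
    let base2k: usize = 17;
    for k in [34usize, 30, 51] {
        let limb = k.div_ceil(base2k) - 1;
        let scale = (1u64 << ((limb + 1) * base2k - k)) as f64;
        println!("k = {k:>2} -> limb {limb}, scale {scale}");
    }
    // k = 34: limb 1, scale 1  (k ends exactly on the two-limb boundary)
    // k = 30: limb 1, scale 16 (k stops 4 bits short of that boundary)
    // k = 51: limb 2, scale 1
}
```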
@@ -291,7 +304,7 @@ where
B: Backend<ScalarBig = i64> + VecZnxBigAllocBytesImpl<B>,
{
let n: usize = module.n();
let basek: usize = 17;
let base2k: usize = 17;
let k: usize = 2 * 17;
let size: usize = 5;
let sigma: f64 = 3.2;
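With these constants, the test below adds two independent Gaussian noises of standard deviation `sigma` into the same column and, after rescaling by `k_f64`, expects a standard deviation of `sigma * SQRT_2`; independent variances simply add. A quick standalone check of the value the assertion compares against:

```rust
use std::f64::consts::SQRT_2;

fn main() {
    let sigma: f64 = 3.2;

    // Var(X + Y) = Var(X) + Var(Y) for independent X, Y, so the summed noise
    // has standard deviation sqrt(2) * sigma.
    let std_of_sum = (sigma * sigma + sigma * sigma).sqrt();

    assert!((std_of_sum - sigma * SQRT_2).abs() < 1e-12);
    println!("expected std after two add_normal calls: {std_of_sum:.4}");
}
```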
@@ -303,15 +316,15 @@ where
let sqrt2: f64 = SQRT_2;
(0..cols).for_each(|col_i| {
let mut a: VecZnxBig<Vec<u8>, B> = VecZnxBig::alloc(n, cols, size);
module.vec_znx_big_add_normal(basek, &mut a, col_i, k, &mut source, sigma, bound);
module.vec_znx_big_add_normal(basek, &mut a, col_i, k, &mut source, sigma, bound);
module.vec_znx_big_add_normal(base2k, &mut a, col_i, k, &mut source, sigma, bound);
module.vec_znx_big_add_normal(base2k, &mut a, col_i, k, &mut source, sigma, bound);
(0..cols).for_each(|col_j| {
if col_j != col_i {
(0..size).for_each(|limb_i| {
assert_eq!(a.at(col_j, limb_i), zero);
})
} else {
let std: f64 = a.std(basek, col_i) * k_f64;
let std: f64 = a.std(base2k, col_i) * k_f64;
assert!(
(std - sigma * sqrt2).abs() < 0.1,
"std={} ~!= {}",
@@ -363,9 +376,9 @@ where
}
/// R <- A - B
pub fn vec_znx_big_sub_ab_inplace<R, A, BE>(res: &mut R, res_col: usize, a: &A, a_col: usize)
pub fn vec_znx_big_sub_inplace<R, A, BE>(res: &mut R, res_col: usize, a: &A, a_col: usize)
where
BE: Backend<ScalarBig = i64> + ZnxSubABInplace,
BE: Backend<ScalarBig = i64> + ZnxSubInplace,
R: VecZnxBigToMut<BE>,
A: VecZnxBigToRef<BE>,
{
@@ -388,13 +401,13 @@ where
max_size: a.max_size,
};
vec_znx_sub_ab_inplace::<_, _, BE>(&mut res_vznx, res_col, &a_vznx, a_col);
vec_znx_sub_inplace::<_, _, BE>(&mut res_vznx, res_col, &a_vznx, a_col);
}
/// R <- B - A
pub fn vec_znx_big_sub_ba_inplace<R, A, BE>(res: &mut R, res_col: usize, a: &A, a_col: usize)
pub fn vec_znx_big_sub_negate_inplace<R, A, BE>(res: &mut R, res_col: usize, a: &A, a_col: usize)
where
BE: Backend<ScalarBig = i64> + ZnxSubBAInplace + ZnxNegateInplace,
BE: Backend<ScalarBig = i64> + ZnxSubNegateInplace + ZnxNegateInplace,
R: VecZnxBigToMut<BE>,
A: VecZnxBigToRef<BE>,
{
@@ -417,7 +430,7 @@ where
max_size: a.max_size,
};
vec_znx_sub_ba_inplace::<_, _, BE>(&mut res_vznx, res_col, &a_vznx, a_col);
vec_znx_sub_negate_inplace::<_, _, BE>(&mut res_vznx, res_col, &a_vznx, a_col);
}
/// R <- A - B
@@ -483,7 +496,7 @@ where
/// R <- R - A
pub fn vec_znx_big_sub_small_a_inplace<R, A, BE>(res: &mut R, res_col: usize, a: &A, a_col: usize)
where
BE: Backend<ScalarBig = i64> + ZnxSubABInplace,
BE: Backend<ScalarBig = i64> + ZnxSubInplace,
R: VecZnxBigToMut<BE>,
A: VecZnxToRef,
{
@@ -497,13 +510,13 @@ where
max_size: res.max_size,
};
vec_znx_sub_ab_inplace::<_, _, BE>(&mut res_vznx, res_col, a, a_col);
vec_znx_sub_inplace::<_, _, BE>(&mut res_vznx, res_col, a, a_col);
}
/// R <- A - R
pub fn vec_znx_big_sub_small_b_inplace<R, A, BE>(res: &mut R, res_col: usize, a: &A, a_col: usize)
where
BE: Backend<ScalarBig = i64> + ZnxSubBAInplace + ZnxNegateInplace,
BE: Backend<ScalarBig = i64> + ZnxSubNegateInplace + ZnxNegateInplace,
R: VecZnxBigToMut<BE>,
A: VecZnxToRef,
{
@@ -517,5 +530,5 @@ where
max_size: res.max_size,
};
vec_znx_sub_ba_inplace::<_, _, BE>(&mut res_vznx, res_col, a, a_col);
vec_znx_sub_negate_inplace::<_, _, BE>(&mut res_vznx, res_col, a, a_col);
}
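On the renames running through the remaining hunks: judging from the doc comments (`R <- R - A` on the `ZnxSubInplace` path, `R <- A - R` on the `ZnxSubNegateInplace` path), `sub_inplace` subtracts the argument from the receiver and `sub_negate_inplace` produces the negated difference. A standalone sketch of that naming convention on plain `i64` slices; it illustrates the convention only, not the crate's ZnX implementation:

```rust
/// res <- res - a  (the `sub_inplace` orientation).
fn sub_inplace(res: &mut [i64], a: &[i64]) {
    for (r, &x) in res.iter_mut().zip(a) {
        *r -= x;
    }
}

/// res <- a - res  (the `sub_negate_inplace` orientation).
fn sub_negate_inplace(res: &mut [i64], a: &[i64]) {
    for (r, &x) in res.iter_mut().zip(a) {
        *r = x - *r;
    }
}

fn main() {
    let a = [10i64, 20, 30];

    let mut r1 = [1i64, 2, 3];
    sub_inplace(&mut r1, &a);
    assert_eq!(r1, [-9, -18, -27]); // res - a

    let mut r2 = [1i64, 2, 3];
    sub_negate_inplace(&mut r2, &a);
    assert_eq!(r2, [9, 18, 27]); // a - res

    println!("sub_inplace: {r1:?}, sub_negate_inplace: {r2:?}");
}
```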