mirror of https://github.com/arnaucube/poulpy.git (synced 2026-02-10 13:16:44 +01:00)
Add cross-basek normalization (#90)
* added cross_basek_normalization
* updated method signatures to take layouts
* fixed cross-base normalization

fix #91
fix #93
committed by GitHub · parent 4da790ea6a · commit 37e13b965c
@@ -37,7 +37,7 @@ pub fn reim_sub_ref(res: &mut [f64], a: &[f64], b: &[f64]) {
 }
 
 #[inline(always)]
-pub fn reim_sub_ab_inplace_ref(res: &mut [f64], a: &[f64]) {
+pub fn reim_sub_inplace_ref(res: &mut [f64], a: &[f64]) {
     #[cfg(debug_assertions)]
     {
         assert_eq!(a.len(), res.len());
@@ -49,7 +49,7 @@ pub fn reim_sub_ab_inplace_ref(res: &mut [f64], a: &[f64]) {
 }
 
 #[inline(always)]
-pub fn reim_sub_ba_inplace_ref(res: &mut [f64], a: &[f64]) {
+pub fn reim_sub_negate_inplace_ref(res: &mut [f64], a: &[f64]) {
     #[cfg(debug_assertions)]
     {
         assert_eq!(a.len(), res.len());
@@ -91,12 +91,12 @@ pub trait ReimSub {
     fn reim_sub(res: &mut [f64], a: &[f64], b: &[f64]);
 }
 
-pub trait ReimSubABInplace {
-    fn reim_sub_ab_inplace(res: &mut [f64], a: &[f64]);
+pub trait ReimSubInplace {
+    fn reim_sub_inplace(res: &mut [f64], a: &[f64]);
 }
 
-pub trait ReimSubBAInplace {
-    fn reim_sub_ba_inplace(res: &mut [f64], a: &[f64]);
+pub trait ReimSubNegateInplace {
+    fn reim_sub_negate_inplace(res: &mut [f64], a: &[f64]);
 }
 
 pub trait ReimNegate {
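The point of the rename: the old `ab`/`ba` suffixes referred to operand slots, while the new names state what happens to `res`. `reim_sub_inplace` computes `res = res - a`, and `reim_sub_negate_inplace` computes `res = a - res`, i.e. the old subtraction negated, hence the name. A free-standing sketch of the pair, with a hypothetical `RefBackend` standing in for a real backend type (not the crate's actual implementations, which also carry debug assertions):

    // Hypothetical free-standing illustration of the renamed trait pair.
    pub trait ReimSubInplace {
        fn reim_sub_inplace(res: &mut [f64], a: &[f64]);
    }
    pub trait ReimSubNegateInplace {
        fn reim_sub_negate_inplace(res: &mut [f64], a: &[f64]);
    }

    struct RefBackend;

    impl ReimSubInplace for RefBackend {
        // res = res - a
        fn reim_sub_inplace(res: &mut [f64], a: &[f64]) {
            res.iter_mut().zip(a).for_each(|(r, &x)| *r -= x);
        }
    }

    impl ReimSubNegateInplace for RefBackend {
        // res = a - res
        fn reim_sub_negate_inplace(res: &mut [f64], a: &[f64]) {
            res.iter_mut().zip(a).for_each(|(r, &x)| *r = x - *r);
        }
    }

    fn main() {
        let mut r = [5.0, 7.0];
        RefBackend::reim_sub_inplace(&mut r, &[1.0, 2.0]);
        assert_eq!(r, [4.0, 5.0]); // res -= a
        RefBackend::reim_sub_negate_inplace(&mut r, &[1.0, 2.0]);
        assert_eq!(r, [-3.0, -3.0]); // res = a - res
    }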
@@ -22,7 +22,7 @@ pub struct ReimFFTTable<R: Float + FloatConst + Debug> {
 
 impl<R: Float + FloatConst + Debug + 'static> ReimFFTTable<R> {
     pub fn new(m: usize) -> Self {
-        assert!(m & (m - 1) == 0, "m must be a power of two but is {}", m);
+        assert!(m & (m - 1) == 0, "m must be a power of two but is {m}");
         let mut omg: Vec<R> = alloc_aligned::<R>(2 * m);
 
         let quarter: R = R::from(1. / 4.).unwrap();
@@ -22,7 +22,7 @@ pub struct ReimIFFTTable<R: Float + FloatConst + Debug> {
 
 impl<R: Float + FloatConst + Debug> ReimIFFTTable<R> {
     pub fn new(m: usize) -> Self {
-        assert!(m & (m - 1) == 0, "m must be a power of two but is {}", m);
+        assert!(m & (m - 1) == 0, "m must be a power of two but is {m}");
         let mut omg: Vec<R> = alloc_aligned::<R>(2 * m);
 
         let quarter: R = R::exp2(R::from(-2).unwrap());
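Both FFT-table hunks are pure message modernization to Rust 2021 inline format args; the guard itself is the classic bit trick: for `m > 0`, `m` is a power of two exactly when `m & (m - 1) == 0`, since `m - 1` flips the lowest set bit and everything below it. A standalone check:

    // m & (m - 1) clears the lowest set bit; the result is zero only when
    // that bit was the sole one, i.e. m is a power of two.
    fn is_power_of_two(m: usize) -> bool {
        m != 0 && m & (m - 1) == 0
    }

    fn main() {
        assert!(is_power_of_two(1024));
        assert!(!is_power_of_two(1000));
        // Note: m = 0 also satisfies m & (m - 1) == 0, so the bare assert in
        // the diff would accept it; table sizes are presumably never zero.
        let m = 1024usize;
        assert!(m & (m - 1) == 0, "m must be a power of two but is {m}");
    }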
@@ -9,12 +9,13 @@ use crate::{
     reference::{
         vec_znx::{
             vec_znx_add, vec_znx_add_inplace, vec_znx_automorphism, vec_znx_automorphism_inplace, vec_znx_negate,
-            vec_znx_negate_inplace, vec_znx_normalize, vec_znx_sub, vec_znx_sub_ab_inplace, vec_znx_sub_ba_inplace,
+            vec_znx_negate_inplace, vec_znx_normalize, vec_znx_sub, vec_znx_sub_inplace, vec_znx_sub_negate_inplace,
         },
         znx::{
-            ZnxAdd, ZnxAddInplace, ZnxAutomorphism, ZnxCopy, ZnxNegate, ZnxNegateInplace, ZnxNormalizeFinalStep,
-            ZnxNormalizeFirstStep, ZnxNormalizeFirstStepCarryOnly, ZnxNormalizeMiddleStep, ZnxNormalizeMiddleStepCarryOnly,
-            ZnxSub, ZnxSubABInplace, ZnxSubBAInplace, ZnxZero, znx_add_normal_f64_ref,
+            ZnxAdd, ZnxAddInplace, ZnxAutomorphism, ZnxCopy, ZnxExtractDigitAddMul, ZnxMulPowerOfTwoInplace, ZnxNegate,
+            ZnxNegateInplace, ZnxNormalizeDigit, ZnxNormalizeFinalStep, ZnxNormalizeFirstStep, ZnxNormalizeFirstStepCarryOnly,
+            ZnxNormalizeMiddleStep, ZnxNormalizeMiddleStepCarryOnly, ZnxSub, ZnxSubInplace, ZnxSubNegateInplace, ZnxZero,
+            znx_add_normal_f64_ref,
         },
     },
     source::Source,
@@ -230,20 +231,32 @@ where
 }
 
 pub fn vec_znx_big_normalize_tmp_bytes(n: usize) -> usize {
-    n * size_of::<i64>()
+    2 * n * size_of::<i64>()
 }
 
-pub fn vec_znx_big_normalize<R, A, BE>(basek: usize, res: &mut R, res_col: usize, a: &A, a_col: usize, carry: &mut [i64])
-where
+pub fn vec_znx_big_normalize<R, A, BE>(
+    res_basek: usize,
+    res: &mut R,
+    res_col: usize,
+    a_basek: usize,
+    a: &A,
+    a_col: usize,
+    carry: &mut [i64],
+) where
     R: VecZnxToMut,
    A: VecZnxBigToRef<BE>,
     BE: Backend<ScalarBig = i64>
         + ZnxZero
         + ZnxCopy
         + ZnxAddInplace
+        + ZnxMulPowerOfTwoInplace
         + ZnxNormalizeFirstStepCarryOnly
         + ZnxNormalizeMiddleStepCarryOnly
         + ZnxNormalizeMiddleStep
         + ZnxNormalizeFinalStep
         + ZnxNormalizeFirstStep
-        + ZnxZero,
+        + ZnxExtractDigitAddMul
+        + ZnxNormalizeDigit,
 {
     let a: VecZnxBig<&[u8], _> = a.to_ref();
     let a_vznx: VecZnx<&[u8]> = VecZnx {
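This is the heart of the PR: normalization now takes a base per operand, so a value carried in base-2^a_basek limbs can be normalized directly into base-2^res_basek limbs, and the scratch requirement doubles to two n-word buffers. The new `ZnxExtractDigitAddMul`, `ZnxNormalizeDigit`, and `ZnxMulPowerOfTwoInplace` bounds are the digit-level plumbing for that re-expansion. A scalar sketch of the underlying idea, using unsigned digits and a materialized big integer (both simplifications for illustration; the crate works limb-by-limb with carries and centered digits and never builds the big integer):

    // Recombine base-2^basek limbs (most significant first) into one
    // integer, then re-expand it in a different base.
    fn recombine(limbs: &[i64], basek: u32) -> i128 {
        limbs.iter().fold(0i128, |acc, &l| (acc << basek) + l as i128)
    }

    fn redecompose(mut v: i128, basek: u32, size: usize) -> Vec<i64> {
        let mask = (1i128 << basek) - 1;
        let mut out = vec![0i64; size];
        for limb in out.iter_mut().rev() {
            *limb = (v & mask) as i64; // lowest digit goes last (MSB-first)
            v >>= basek;
        }
        out
    }

    fn main() {
        let a = [3i64, -5, 7]; // a value in base 2^17, limbs MSB-first
        let v = recombine(&a, 17);
        let r = redecompose(v, 13, 4); // the same value in base 2^13
        assert_eq!(recombine(&r, 13), v);
    }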
@@ -254,11 +267,11 @@ where
         max_size: a.max_size,
     };
 
-    vec_znx_normalize::<_, _, BE>(basek, res, res_col, &a_vznx, a_col, carry);
+    vec_znx_normalize::<_, _, BE>(res_basek, res, res_col, a_basek, &a_vznx, a_col, carry);
 }
 
 pub fn vec_znx_big_add_normal_ref<R, B: Backend<ScalarBig = i64>>(
-    basek: usize,
+    base2k: usize,
     res: &mut R,
     res_col: usize,
     k: usize,
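A caller-side sketch of the new contract, assuming only what the hunks show: two bases passed in the order `(res_basek, ..., a_basek, ...)`, each grouped with its operand, and a carry buffer sized by the doubled `vec_znx_big_normalize_tmp_bytes`:

    // Hypothetical caller; the Module / VecZnx plumbing is elided.
    fn vec_znx_big_normalize_tmp_bytes(n: usize) -> usize {
        2 * n * std::mem::size_of::<i64>()
    }

    fn main() {
        let n = 1 << 10;
        let words = vec_znx_big_normalize_tmp_bytes(n) / std::mem::size_of::<i64>();
        let mut carry = vec![0i64; words];
        assert_eq!(carry.len(), 2 * n); // presumably one n-word buffer per base
        carry[0] = 1; // stand-in for the real call shown in the hunk above:
        // vec_znx_big_normalize(res_basek, &mut res, 0, a_basek, &a, 0, &mut carry);
    }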
@@ -275,8 +288,8 @@ pub fn vec_znx_big_add_normal_ref<R, B: Backend<ScalarBig = i64>>(
         (bound.log2().ceil() as i64)
     );
 
-    let limb: usize = k.div_ceil(basek) - 1;
-    let scale: f64 = (1 << ((limb + 1) * basek - k)) as f64;
+    let limb: usize = k.div_ceil(base2k) - 1;
+    let scale: f64 = (1 << ((limb + 1) * base2k - k)) as f64;
     znx_add_normal_f64_ref(
         res.at_mut(res_col, limb),
         sigma * scale,
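Only the parameter name changes here (`basek` to `base2k`), but the arithmetic deserves a note: `limb` is the index of the base-2^base2k limb that holds bit k, and `scale` lifts the noise by the gap between k and that limb's boundary, `2^((limb + 1) * base2k - k)`. A quick standalone check of both the aligned and unaligned cases:

    fn limb_and_scale(k: usize, base2k: usize) -> (usize, f64) {
        let limb = k.div_ceil(base2k) - 1;
        let scale = (1u64 << ((limb + 1) * base2k - k)) as f64;
        (limb, scale)
    }

    fn main() {
        // k = 34 lands exactly on the limb-1 boundary: noise is unscaled.
        assert_eq!(limb_and_scale(34, 17), (1, 1.0));
        // k = 30 sits 4 bits short of that boundary: noise is scaled by 2^4.
        assert_eq!(limb_and_scale(30, 17), (1, 16.0));
    }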
@@ -291,7 +304,7 @@ where
     B: Backend<ScalarBig = i64> + VecZnxBigAllocBytesImpl<B>,
 {
     let n: usize = module.n();
-    let basek: usize = 17;
+    let base2k: usize = 17;
     let k: usize = 2 * 17;
     let size: usize = 5;
     let sigma: f64 = 3.2;
@@ -303,15 +316,15 @@ where
     let sqrt2: f64 = SQRT_2;
     (0..cols).for_each(|col_i| {
         let mut a: VecZnxBig<Vec<u8>, B> = VecZnxBig::alloc(n, cols, size);
-        module.vec_znx_big_add_normal(basek, &mut a, col_i, k, &mut source, sigma, bound);
-        module.vec_znx_big_add_normal(basek, &mut a, col_i, k, &mut source, sigma, bound);
+        module.vec_znx_big_add_normal(base2k, &mut a, col_i, k, &mut source, sigma, bound);
+        module.vec_znx_big_add_normal(base2k, &mut a, col_i, k, &mut source, sigma, bound);
         (0..cols).for_each(|col_j| {
             if col_j != col_i {
                 (0..size).for_each(|limb_i| {
                     assert_eq!(a.at(col_j, limb_i), zero);
                 })
             } else {
-                let std: f64 = a.std(basek, col_i) * k_f64;
+                let std: f64 = a.std(base2k, col_i) * k_f64;
                 assert!(
                     (std - sigma * sqrt2).abs() < 0.1,
                     "std={} ~!= {}",
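The test adds the same `N(0, sigma^2)` noise twice into column `col_i`, so the measured standard deviation should be `sigma * SQRT_2`: variances of independent Gaussians add. A self-contained Monte Carlo confirmation of just that statistical fact (toy LCG plus Box-Muller, purely illustrative; the crate draws from its own `Source`):

    fn main() {
        let sigma = 3.2f64;
        let mut state = 0x2545_F491_4F6C_DD1Du64;
        // Knuth-style LCG; take the top 53 bits for a uniform in (0, 1].
        let mut uniform = || {
            state = state
                .wrapping_mul(6364136223846793005)
                .wrapping_add(1442695040888963407);
            ((state >> 11) as f64 + 1.0) / (1u64 << 53) as f64
        };
        let trials = 1_000_000;
        let mut sum_sq = 0.0;
        for _ in 0..trials {
            // Two independent standard normals via Box-Muller.
            let g1 = (-2.0 * uniform().ln()).sqrt() * (std::f64::consts::TAU * uniform()).cos();
            let g2 = (-2.0 * uniform().ln()).sqrt() * (std::f64::consts::TAU * uniform()).sin();
            let s = sigma * (g1 + g2); // noise added twice
            sum_sq += s * s;
        }
        let std = (sum_sq / trials as f64).sqrt();
        assert!((std - sigma * std::f64::consts::SQRT_2).abs() < 0.1, "std={std}");
    }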
@@ -363,9 +376,9 @@ where
 }
 
 /// R <- A - B
-pub fn vec_znx_big_sub_ab_inplace<R, A, BE>(res: &mut R, res_col: usize, a: &A, a_col: usize)
+pub fn vec_znx_big_sub_inplace<R, A, BE>(res: &mut R, res_col: usize, a: &A, a_col: usize)
 where
-    BE: Backend<ScalarBig = i64> + ZnxSubABInplace,
+    BE: Backend<ScalarBig = i64> + ZnxSubInplace,
     R: VecZnxBigToMut<BE>,
     A: VecZnxBigToRef<BE>,
 {
@@ -388,13 +401,13 @@ where
         max_size: a.max_size,
     };
 
-    vec_znx_sub_ab_inplace::<_, _, BE>(&mut res_vznx, res_col, &a_vznx, a_col);
+    vec_znx_sub_inplace::<_, _, BE>(&mut res_vznx, res_col, &a_vznx, a_col);
 }
 
 /// R <- B - A
-pub fn vec_znx_big_sub_ba_inplace<R, A, BE>(res: &mut R, res_col: usize, a: &A, a_col: usize)
+pub fn vec_znx_big_sub_negate_inplace<R, A, BE>(res: &mut R, res_col: usize, a: &A, a_col: usize)
 where
-    BE: Backend<ScalarBig = i64> + ZnxSubBAInplace + ZnxNegateInplace,
+    BE: Backend<ScalarBig = i64> + ZnxSubNegateInplace + ZnxNegateInplace,
     R: VecZnxBigToMut<BE>,
     A: VecZnxBigToRef<BE>,
 {
@@ -417,7 +430,7 @@ where
         max_size: a.max_size,
     };
 
-    vec_znx_sub_ba_inplace::<_, _, BE>(&mut res_vznx, res_col, &a_vznx, a_col);
+    vec_znx_sub_negate_inplace::<_, _, BE>(&mut res_vznx, res_col, &a_vznx, a_col);
 }
 
 /// R <- A - B
@@ -483,7 +496,7 @@ where
 /// R <- R - A
 pub fn vec_znx_big_sub_small_a_inplace<R, A, BE>(res: &mut R, res_col: usize, a: &A, a_col: usize)
 where
-    BE: Backend<ScalarBig = i64> + ZnxSubABInplace,
+    BE: Backend<ScalarBig = i64> + ZnxSubInplace,
     R: VecZnxBigToMut<BE>,
     A: VecZnxToRef,
 {
@@ -497,13 +510,13 @@ where
         max_size: res.max_size,
     };
 
-    vec_znx_sub_ab_inplace::<_, _, BE>(&mut res_vznx, res_col, a, a_col);
+    vec_znx_sub_inplace::<_, _, BE>(&mut res_vznx, res_col, a, a_col);
 }
 
 /// R <- A - R
 pub fn vec_znx_big_sub_small_b_inplace<R, A, BE>(res: &mut R, res_col: usize, a: &A, a_col: usize)
 where
-    BE: Backend<ScalarBig = i64> + ZnxSubBAInplace + ZnxNegateInplace,
+    BE: Backend<ScalarBig = i64> + ZnxSubNegateInplace + ZnxNegateInplace,
     R: VecZnxBigToMut<BE>,
     A: VecZnxToRef,
 {
@@ -517,5 +530,5 @@ where
         max_size: res.max_size,
     };
 
-    vec_znx_sub_ba_inplace::<_, _, BE>(&mut res_vznx, res_col, a, a_col);
+    vec_znx_sub_negate_inplace::<_, _, BE>(&mut res_vznx, res_col, a, a_col);
 }
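Taken together, the renames in this file replace the positional `ab`/`ba` suffixes with names that state the effect on `res`: `vec_znx_big_sub_inplace` (bound `ZnxSubInplace`) computes `R <- R - A`, and `vec_znx_big_sub_negate_inplace` (bound `ZnxSubNegateInplace`, plus `ZnxNegateInplace`) computes `R <- A - R`, matching the doc comments on the `sub_small_*` variants, which keep their names and only swap trait bounds.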
@@ -8,7 +8,7 @@ use crate::{
     reference::{
         fft64::reim::{
             ReimAdd, ReimAddInplace, ReimCopy, ReimDFTExecute, ReimFFTTable, ReimFromZnx, ReimIFFTTable, ReimNegate,
-            ReimNegateInplace, ReimSub, ReimSubABInplace, ReimSubBAInplace, ReimToZnx, ReimToZnxInplace, ReimZero,
+            ReimNegateInplace, ReimSub, ReimSubInplace, ReimSubNegateInplace, ReimToZnx, ReimToZnxInplace, ReimZero,
         },
         znx::ZnxZero,
     },
@@ -308,9 +308,9 @@ where
     }
 }
 
-pub fn vec_znx_dft_sub_ab_inplace<R, A, BE>(res: &mut R, res_col: usize, a: &A, a_col: usize)
+pub fn vec_znx_dft_sub_inplace<R, A, BE>(res: &mut R, res_col: usize, a: &A, a_col: usize)
 where
-    BE: Backend<ScalarPrep = f64> + ReimSubABInplace,
+    BE: Backend<ScalarPrep = f64> + ReimSubInplace,
     R: VecZnxDftToMut<BE>,
     A: VecZnxDftToRef<BE>,
 {
@@ -328,13 +328,13 @@ where
     let sum_size: usize = a_size.min(res_size);
 
     for j in 0..sum_size {
-        BE::reim_sub_ab_inplace(res.at_mut(res_col, j), a.at(a_col, j));
+        BE::reim_sub_inplace(res.at_mut(res_col, j), a.at(a_col, j));
     }
 }
 
-pub fn vec_znx_dft_sub_ba_inplace<R, A, BE>(res: &mut R, res_col: usize, a: &A, a_col: usize)
+pub fn vec_znx_dft_sub_negate_inplace<R, A, BE>(res: &mut R, res_col: usize, a: &A, a_col: usize)
 where
-    BE: Backend<ScalarPrep = f64> + ReimSubBAInplace + ReimNegateInplace,
+    BE: Backend<ScalarPrep = f64> + ReimSubNegateInplace + ReimNegateInplace,
     R: VecZnxDftToMut<BE>,
     A: VecZnxDftToRef<BE>,
 {
@@ -352,7 +352,7 @@ where
     let sum_size: usize = a_size.min(res_size);
 
     for j in 0..sum_size {
-        BE::reim_sub_ba_inplace(res.at_mut(res_col, j), a.at(a_col, j));
+        BE::reim_sub_negate_inplace(res.at_mut(res_col, j), a.at(a_col, j));
     }
 
     for j in sum_size..res_size {
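The trailing loop visible at the end of this hunk covers the limbs `res` has but `a` does not: there `a` is implicitly zero, so `res = a - res` degenerates to plain negation, which is why the bounds pair `ReimSubNegateInplace` with `ReimNegateInplace`. A toy model of the two-phase pattern over plain slices (hypothetical helpers; the crate dispatches through `BE`):

    fn reim_sub_negate_inplace(res: &mut [f64], a: &[f64]) {
        res.iter_mut().zip(a).for_each(|(r, &x)| *r = x - *r); // res = a - res
    }

    fn reim_negate_inplace(res: &mut [f64]) {
        res.iter_mut().for_each(|r| *r = -*r);
    }

    fn main() {
        // res has 3 "limbs", a only 2: the tail limb is simply negated.
        let mut res = vec![[1.0f64, 1.0], [2.0, 2.0], [3.0, 3.0]];
        let a = vec![[10.0f64, 10.0], [20.0, 20.0]];
        let sum_size = a.len().min(res.len());
        for j in 0..sum_size {
            reim_sub_negate_inplace(&mut res[j], &a[j]);
        }
        for j in sum_size..res.len() {
            reim_negate_inplace(&mut res[j]); // a[j] is implicitly zero
        }
        assert_eq!(res, vec![[9.0, 9.0], [18.0, 18.0], [-3.0, -3.0]]);
    }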