Add cross-basek normalization (#90)

* added cross_basek_normalization

* updated method signatures to take layouts

* fixed cross-base normalization

fix #91
fix #93
Jean-Philippe Bossuat
2025-09-30 14:40:10 +02:00
committed by GitHub
parent 4da790ea6a
commit 37e13b965c
216 changed files with 12481 additions and 7745 deletions
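At a glance, the user-facing change is that the normalization entry points now take the base2k of the input alongside the base2k of the output, so a vector encoded in one base can be normalized into another. A before/after sketch of the call shape, copied from the updated bench further down (surrounding setup elided; both bases happen to be equal there):

// Before: a single base2k shared by input and output.
module.vec_znx_normalize(basek, &mut res, i, &a, i, scratch.borrow());
// After (#90): output base first, input base next to the input operand.
module.vec_znx_normalize(base2k, &mut res, i, base2k, &a, i, scratch.borrow());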

View File

@@ -37,7 +37,7 @@ pub fn reim_sub_ref(res: &mut [f64], a: &[f64], b: &[f64]) {
}
#[inline(always)]
pub fn reim_sub_ab_inplace_ref(res: &mut [f64], a: &[f64]) {
pub fn reim_sub_inplace_ref(res: &mut [f64], a: &[f64]) {
#[cfg(debug_assertions)]
{
assert_eq!(a.len(), res.len());
@@ -49,7 +49,7 @@ pub fn reim_sub_ab_inplace_ref(res: &mut [f64], a: &[f64]) {
}
#[inline(always)]
pub fn reim_sub_ba_inplace_ref(res: &mut [f64], a: &[f64]) {
pub fn reim_sub_negate_inplace_ref(res: &mut [f64], a: &[f64]) {
#[cfg(debug_assertions)]
{
assert_eq!(a.len(), res.len());

View File

@@ -91,12 +91,12 @@ pub trait ReimSub {
fn reim_sub(res: &mut [f64], a: &[f64], b: &[f64]);
}
pub trait ReimSubABInplace {
fn reim_sub_ab_inplace(res: &mut [f64], a: &[f64]);
pub trait ReimSubInplace {
fn reim_sub_inplace(res: &mut [f64], a: &[f64]);
}
pub trait ReimSubBAInplace {
fn reim_sub_ba_inplace(res: &mut [f64], a: &[f64]);
pub trait ReimSubNegateInplace {
fn reim_sub_negate_inplace(res: &mut [f64], a: &[f64]);
}
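The ab/ba suffixes are replaced by names that spell out which operand is kept. A minimal sketch of the two in-place semantics on f64 slices, following the R <- R - A and R <- A - R conventions documented on the vec_znx_big wrappers further down (the reference bodies themselves are not shown in this hunk):

// res <- res - a (formerly *_sub_ab_inplace)
fn sub_inplace_sketch(res: &mut [f64], a: &[f64]) {
    for (r, &x) in res.iter_mut().zip(a) {
        *r -= x;
    }
}

// res <- a - res (formerly *_sub_ba_inplace), i.e. the negated difference
fn sub_negate_inplace_sketch(res: &mut [f64], a: &[f64]) {
    for (r, &x) in res.iter_mut().zip(a) {
        *r = x - *r;
    }
}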
pub trait ReimNegate {

View File

@@ -22,7 +22,7 @@ pub struct ReimFFTTable<R: Float + FloatConst + Debug> {
impl<R: Float + FloatConst + Debug + 'static> ReimFFTTable<R> {
pub fn new(m: usize) -> Self {
assert!(m & (m - 1) == 0, "m must be a power of two but is {}", m);
assert!(m & (m - 1) == 0, "m must be a power of two but is {m}");
let mut omg: Vec<R> = alloc_aligned::<R>(2 * m);
let quarter: R = R::from(1. / 4.).unwrap();

View File

@@ -22,7 +22,7 @@ pub struct ReimIFFTTable<R: Float + FloatConst + Debug> {
impl<R: Float + FloatConst + Debug> ReimIFFTTable<R> {
pub fn new(m: usize) -> Self {
assert!(m & (m - 1) == 0, "m must be a power of two but is {}", m);
assert!(m & (m - 1) == 0, "m must be a power of two but is {m}");
let mut omg: Vec<R> = alloc_aligned::<R>(2 * m);
let quarter: R = R::exp2(R::from(-2).unwrap());

View File

@@ -9,12 +9,13 @@ use crate::{
reference::{
vec_znx::{
vec_znx_add, vec_znx_add_inplace, vec_znx_automorphism, vec_znx_automorphism_inplace, vec_znx_negate,
vec_znx_negate_inplace, vec_znx_normalize, vec_znx_sub, vec_znx_sub_ab_inplace, vec_znx_sub_ba_inplace,
vec_znx_negate_inplace, vec_znx_normalize, vec_znx_sub, vec_znx_sub_inplace, vec_znx_sub_negate_inplace,
},
znx::{
ZnxAdd, ZnxAddInplace, ZnxAutomorphism, ZnxCopy, ZnxNegate, ZnxNegateInplace, ZnxNormalizeFinalStep,
ZnxNormalizeFirstStep, ZnxNormalizeFirstStepCarryOnly, ZnxNormalizeMiddleStep, ZnxNormalizeMiddleStepCarryOnly,
ZnxSub, ZnxSubABInplace, ZnxSubBAInplace, ZnxZero, znx_add_normal_f64_ref,
ZnxAdd, ZnxAddInplace, ZnxAutomorphism, ZnxCopy, ZnxExtractDigitAddMul, ZnxMulPowerOfTwoInplace, ZnxNegate,
ZnxNegateInplace, ZnxNormalizeDigit, ZnxNormalizeFinalStep, ZnxNormalizeFirstStep, ZnxNormalizeFirstStepCarryOnly,
ZnxNormalizeMiddleStep, ZnxNormalizeMiddleStepCarryOnly, ZnxSub, ZnxSubInplace, ZnxSubNegateInplace, ZnxZero,
znx_add_normal_f64_ref,
},
},
source::Source,
@@ -230,20 +231,32 @@ where
}
pub fn vec_znx_big_normalize_tmp_bytes(n: usize) -> usize {
n * size_of::<i64>()
2 * n * size_of::<i64>()
}
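The scratch requirement doubles because the cross-base path needs a second n-sized buffer next to the carry; vec_znx_normalize further down splits it with let (a_norm, carry) = carry.split_at_mut(n). A tiny sketch of that split, assuming only that the caller provides at least 2 * n i64 values:

// Splits a 2*n scratch area into the renormalized-limb buffer and the carry buffer.
fn split_normalize_scratch(scratch: &mut [i64], n: usize) -> (&mut [i64], &mut [i64]) {
    debug_assert!(scratch.len() >= 2 * n);
    let (a_norm, rest) = scratch.split_at_mut(n);
    (a_norm, &mut rest[..n])
}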
pub fn vec_znx_big_normalize<R, A, BE>(basek: usize, res: &mut R, res_col: usize, a: &A, a_col: usize, carry: &mut [i64])
where
pub fn vec_znx_big_normalize<R, A, BE>(
res_basek: usize,
res: &mut R,
res_col: usize,
a_basek: usize,
a: &A,
a_col: usize,
carry: &mut [i64],
) where
R: VecZnxToMut,
A: VecZnxBigToRef<BE>,
BE: Backend<ScalarBig = i64>
+ ZnxZero
+ ZnxCopy
+ ZnxAddInplace
+ ZnxMulPowerOfTwoInplace
+ ZnxNormalizeFirstStepCarryOnly
+ ZnxNormalizeMiddleStepCarryOnly
+ ZnxNormalizeMiddleStep
+ ZnxNormalizeFinalStep
+ ZnxNormalizeFirstStep
+ ZnxZero,
+ ZnxExtractDigitAddMul
+ ZnxNormalizeDigit,
{
let a: VecZnxBig<&[u8], _> = a.to_ref();
let a_vznx: VecZnx<&[u8]> = VecZnx {
@@ -254,11 +267,11 @@ where
max_size: a.max_size,
};
vec_znx_normalize::<_, _, BE>(basek, res, res_col, &a_vznx, a_col, carry);
vec_znx_normalize::<_, _, BE>(res_basek, res, res_col, a_basek, &a_vznx, a_col, carry);
}
pub fn vec_znx_big_add_normal_ref<R, B: Backend<ScalarBig = i64>>(
basek: usize,
base2k: usize,
res: &mut R,
res_col: usize,
k: usize,
@@ -275,8 +288,8 @@ pub fn vec_znx_big_add_normal_ref<R, B: Backend<ScalarBig = i64>>(
(bound.log2().ceil() as i64)
);
let limb: usize = k.div_ceil(basek) - 1;
let scale: f64 = (1 << ((limb + 1) * basek - k)) as f64;
let limb: usize = k.div_ceil(base2k) - 1;
let scale: f64 = (1 << ((limb + 1) * base2k - k)) as f64;
znx_add_normal_f64_ref(
res.at_mut(res_col, limb),
sigma * scale,
@@ -291,7 +304,7 @@ where
B: Backend<ScalarBig = i64> + VecZnxBigAllocBytesImpl<B>,
{
let n: usize = module.n();
let basek: usize = 17;
let base2k: usize = 17;
let k: usize = 2 * 17;
let size: usize = 5;
let sigma: f64 = 3.2;
@@ -303,15 +316,15 @@ where
let sqrt2: f64 = SQRT_2;
(0..cols).for_each(|col_i| {
let mut a: VecZnxBig<Vec<u8>, B> = VecZnxBig::alloc(n, cols, size);
module.vec_znx_big_add_normal(basek, &mut a, col_i, k, &mut source, sigma, bound);
module.vec_znx_big_add_normal(basek, &mut a, col_i, k, &mut source, sigma, bound);
module.vec_znx_big_add_normal(base2k, &mut a, col_i, k, &mut source, sigma, bound);
module.vec_znx_big_add_normal(base2k, &mut a, col_i, k, &mut source, sigma, bound);
(0..cols).for_each(|col_j| {
if col_j != col_i {
(0..size).for_each(|limb_i| {
assert_eq!(a.at(col_j, limb_i), zero);
})
} else {
let std: f64 = a.std(basek, col_i) * k_f64;
let std: f64 = a.std(base2k, col_i) * k_f64;
assert!(
(std - sigma * sqrt2).abs() < 0.1,
"std={} ~!= {}",
@@ -363,9 +376,9 @@ where
}
/// R <- R - A
pub fn vec_znx_big_sub_ab_inplace<R, A, BE>(res: &mut R, res_col: usize, a: &A, a_col: usize)
pub fn vec_znx_big_sub_inplace<R, A, BE>(res: &mut R, res_col: usize, a: &A, a_col: usize)
where
BE: Backend<ScalarBig = i64> + ZnxSubABInplace,
BE: Backend<ScalarBig = i64> + ZnxSubInplace,
R: VecZnxBigToMut<BE>,
A: VecZnxBigToRef<BE>,
{
@@ -388,13 +401,13 @@ where
max_size: a.max_size,
};
vec_znx_sub_ab_inplace::<_, _, BE>(&mut res_vznx, res_col, &a_vznx, a_col);
vec_znx_sub_inplace::<_, _, BE>(&mut res_vznx, res_col, &a_vznx, a_col);
}
/// R <- A - R
pub fn vec_znx_big_sub_ba_inplace<R, A, BE>(res: &mut R, res_col: usize, a: &A, a_col: usize)
pub fn vec_znx_big_sub_negate_inplace<R, A, BE>(res: &mut R, res_col: usize, a: &A, a_col: usize)
where
BE: Backend<ScalarBig = i64> + ZnxSubBAInplace + ZnxNegateInplace,
BE: Backend<ScalarBig = i64> + ZnxSubNegateInplace + ZnxNegateInplace,
R: VecZnxBigToMut<BE>,
A: VecZnxBigToRef<BE>,
{
@@ -417,7 +430,7 @@ where
max_size: a.max_size,
};
vec_znx_sub_ba_inplace::<_, _, BE>(&mut res_vznx, res_col, &a_vznx, a_col);
vec_znx_sub_negate_inplace::<_, _, BE>(&mut res_vznx, res_col, &a_vznx, a_col);
}
/// R <- A - B
@@ -483,7 +496,7 @@ where
/// R <- R - A
pub fn vec_znx_big_sub_small_a_inplace<R, A, BE>(res: &mut R, res_col: usize, a: &A, a_col: usize)
where
BE: Backend<ScalarBig = i64> + ZnxSubABInplace,
BE: Backend<ScalarBig = i64> + ZnxSubInplace,
R: VecZnxBigToMut<BE>,
A: VecZnxToRef,
{
@@ -497,13 +510,13 @@ where
max_size: res.max_size,
};
vec_znx_sub_ab_inplace::<_, _, BE>(&mut res_vznx, res_col, a, a_col);
vec_znx_sub_inplace::<_, _, BE>(&mut res_vznx, res_col, a, a_col);
}
/// R <- A - R
pub fn vec_znx_big_sub_small_b_inplace<R, A, BE>(res: &mut R, res_col: usize, a: &A, a_col: usize)
where
BE: Backend<ScalarBig = i64> + ZnxSubBAInplace + ZnxNegateInplace,
BE: Backend<ScalarBig = i64> + ZnxSubNegateInplace + ZnxNegateInplace,
R: VecZnxBigToMut<BE>,
A: VecZnxToRef,
{
@@ -517,5 +530,5 @@ where
max_size: res.max_size,
};
vec_znx_sub_ba_inplace::<_, _, BE>(&mut res_vznx, res_col, a, a_col);
vec_znx_sub_negate_inplace::<_, _, BE>(&mut res_vznx, res_col, a, a_col);
}

View File

@@ -8,7 +8,7 @@ use crate::{
reference::{
fft64::reim::{
ReimAdd, ReimAddInplace, ReimCopy, ReimDFTExecute, ReimFFTTable, ReimFromZnx, ReimIFFTTable, ReimNegate,
ReimNegateInplace, ReimSub, ReimSubABInplace, ReimSubBAInplace, ReimToZnx, ReimToZnxInplace, ReimZero,
ReimNegateInplace, ReimSub, ReimSubInplace, ReimSubNegateInplace, ReimToZnx, ReimToZnxInplace, ReimZero,
},
znx::ZnxZero,
},
@@ -308,9 +308,9 @@ where
}
}
pub fn vec_znx_dft_sub_ab_inplace<R, A, BE>(res: &mut R, res_col: usize, a: &A, a_col: usize)
pub fn vec_znx_dft_sub_inplace<R, A, BE>(res: &mut R, res_col: usize, a: &A, a_col: usize)
where
BE: Backend<ScalarPrep = f64> + ReimSubABInplace,
BE: Backend<ScalarPrep = f64> + ReimSubInplace,
R: VecZnxDftToMut<BE>,
A: VecZnxDftToRef<BE>,
{
@@ -328,13 +328,13 @@ where
let sum_size: usize = a_size.min(res_size);
for j in 0..sum_size {
BE::reim_sub_ab_inplace(res.at_mut(res_col, j), a.at(a_col, j));
BE::reim_sub_inplace(res.at_mut(res_col, j), a.at(a_col, j));
}
}
pub fn vec_znx_dft_sub_ba_inplace<R, A, BE>(res: &mut R, res_col: usize, a: &A, a_col: usize)
pub fn vec_znx_dft_sub_negate_inplace<R, A, BE>(res: &mut R, res_col: usize, a: &A, a_col: usize)
where
BE: Backend<ScalarPrep = f64> + ReimSubBAInplace + ReimNegateInplace,
BE: Backend<ScalarPrep = f64> + ReimSubNegateInplace + ReimNegateInplace,
R: VecZnxDftToMut<BE>,
A: VecZnxDftToRef<BE>,
{
@@ -352,7 +352,7 @@ where
let sum_size: usize = a_size.min(res_size);
for j in 0..sum_size {
BE::reim_sub_ba_inplace(res.at_mut(res_col, j), a.at(a_col, j));
BE::reim_sub_negate_inplace(res.at_mut(res_col, j), a.at(a_col, j));
}
for j in sum_size..res_size {

View File

@@ -91,7 +91,7 @@ pub fn bench_vec_znx_add<B: Backend>(c: &mut Criterion, label: &str)
where
Module<B>: VecZnxAdd + ModuleNew<B>,
{
let group_name: String = format!("vec_znx_add::{}", label);
let group_name: String = format!("vec_znx_add::{label}");
let mut group = c.benchmark_group(group_name);
@@ -136,7 +136,7 @@ pub fn bench_vec_znx_add_inplace<B: Backend>(c: &mut Criterion, label: &str)
where
Module<B>: VecZnxAddInplace + ModuleNew<B>,
{
let group_name: String = format!("vec_znx_add_inplace::{}", label);
let group_name: String = format!("vec_znx_add_inplace::{label}");
let mut group = c.benchmark_group(group_name);

View File

@@ -18,12 +18,7 @@ where
#[cfg(debug_assertions)]
{
assert!(
b_limb < min_size,
"b_limb: {} > min_size: {}",
b_limb,
min_size
);
assert!(b_limb < min_size, "b_limb: {b_limb} > min_size: {min_size}");
}
for j in 0..min_size {

View File

@@ -63,7 +63,7 @@ pub fn bench_vec_znx_automorphism<B: Backend>(c: &mut Criterion, label: &str)
where
Module<B>: VecZnxAutomorphism + ModuleNew<B>,
{
let group_name: String = format!("vec_znx_automorphism::{}", label);
let group_name: String = format!("vec_znx_automorphism::{label}");
let mut group = c.benchmark_group(group_name);
@@ -108,7 +108,7 @@ where
Module<B>: VecZnxAutomorphismInplace<B> + VecZnxAutomorphismInplaceTmpBytes + ModuleNew<B>,
ScratchOwned<B>: ScratchOwnedAlloc<B> + ScratchOwnedBorrow<B>,
{
let group_name: String = format!("vec_znx_automorphism_inplace::{}", label);
let group_name: String = format!("vec_znx_automorphism_inplace::{label}");
let mut group = c.benchmark_group(group_name);

View File

@@ -9,8 +9,8 @@ use crate::{
},
layouts::{Backend, FillUniform, Module, ScratchOwned, VecZnx, VecZnxToMut, VecZnxToRef, ZnxInfos, ZnxView, ZnxViewMut},
reference::{
vec_znx::{vec_znx_rotate, vec_znx_sub_ab_inplace},
znx::{ZnxNegate, ZnxRotate, ZnxSubABInplace, ZnxSubBAInplace, ZnxZero},
vec_znx::{vec_znx_rotate, vec_znx_sub_inplace},
znx::{ZnxNegate, ZnxRotate, ZnxSubInplace, ZnxSubNegateInplace, ZnxZero},
},
source::Source,
};
@@ -23,16 +23,16 @@ pub fn vec_znx_mul_xp_minus_one<R, A, ZNXARI>(p: i64, res: &mut R, res_col: usiz
where
R: VecZnxToMut,
A: VecZnxToRef,
ZNXARI: ZnxRotate + ZnxZero + ZnxSubABInplace,
ZNXARI: ZnxRotate + ZnxZero + ZnxSubInplace,
{
vec_znx_rotate::<_, _, ZNXARI>(p, res, res_col, a, a_col);
vec_znx_sub_ab_inplace::<_, _, ZNXARI>(res, res_col, a, a_col);
vec_znx_sub_inplace::<_, _, ZNXARI>(res, res_col, a, a_col);
}
pub fn vec_znx_mul_xp_minus_one_inplace<R, ZNXARI>(p: i64, res: &mut R, res_col: usize, tmp: &mut [i64])
where
R: VecZnxToMut,
ZNXARI: ZnxRotate + ZnxNegate + ZnxSubBAInplace,
ZNXARI: ZnxRotate + ZnxNegate + ZnxSubNegateInplace,
{
let mut res: VecZnx<&mut [u8]> = res.to_mut();
#[cfg(debug_assertions)]
@@ -41,7 +41,7 @@ where
}
for j in 0..res.size() {
ZNXARI::znx_rotate(p, tmp, res.at(res_col, j));
ZNXARI::znx_sub_ba_inplace(res.at_mut(res_col, j), tmp);
ZNXARI::znx_sub_negate_inplace(res.at_mut(res_col, j), tmp);
}
}
@@ -49,7 +49,7 @@ pub fn bench_vec_znx_mul_xp_minus_one<B: Backend>(c: &mut Criterion, label: &str
where
Module<B>: VecZnxMulXpMinusOne + ModuleNew<B>,
{
let group_name: String = format!("vec_znx_mul_xp_minus_one::{}", label);
let group_name: String = format!("vec_znx_mul_xp_minus_one::{label}");
let mut group = c.benchmark_group(group_name);
@@ -94,7 +94,7 @@ where
Module<B>: VecZnxMulXpMinusOneInplace<B> + VecZnxMulXpMinusOneInplaceTmpBytes + ModuleNew<B>,
ScratchOwned<B>: ScratchOwnedAlloc<B> + ScratchOwnedBorrow<B>,
{
let group_name: String = format!("vec_znx_mul_xp_minus_one_inplace::{}", label);
let group_name: String = format!("vec_znx_mul_xp_minus_one_inplace::{label}");
let mut group = c.benchmark_group(group_name);

View File

@@ -49,7 +49,7 @@ pub fn bench_vec_znx_negate<B: Backend>(c: &mut Criterion, label: &str)
where
Module<B>: VecZnxNegate + ModuleNew<B>,
{
let group_name: String = format!("vec_znx_negate::{}", label);
let group_name: String = format!("vec_znx_negate::{label}");
let mut group = c.benchmark_group(group_name);
@@ -93,7 +93,7 @@ pub fn bench_vec_znx_negate_inplace<B: Backend>(c: &mut Criterion, label: &str)
where
Module<B>: VecZnxNegateInplace + ModuleNew<B>,
{
let group_name: String = format!("vec_znx_negate_inplace::{}", label);
let group_name: String = format!("vec_znx_negate_inplace::{label}");
let mut group = c.benchmark_group(group_name);

View File

@@ -6,71 +6,204 @@ use crate::{
api::{ModuleNew, ScratchOwnedAlloc, ScratchOwnedBorrow, VecZnxNormalize, VecZnxNormalizeInplace, VecZnxNormalizeTmpBytes},
layouts::{Backend, FillUniform, Module, ScratchOwned, VecZnx, VecZnxToMut, VecZnxToRef, ZnxInfos, ZnxView, ZnxViewMut},
reference::znx::{
ZnxNormalizeFinalStep, ZnxNormalizeFinalStepInplace, ZnxNormalizeFirstStep, ZnxNormalizeFirstStepCarryOnly,
ZnxNormalizeFirstStepInplace, ZnxNormalizeMiddleStep, ZnxNormalizeMiddleStepCarryOnly, ZnxNormalizeMiddleStepInplace,
ZnxZero,
ZnxAddInplace, ZnxCopy, ZnxExtractDigitAddMul, ZnxMulPowerOfTwoInplace, ZnxNormalizeDigit, ZnxNormalizeFinalStep,
ZnxNormalizeFinalStepInplace, ZnxNormalizeFirstStep, ZnxNormalizeFirstStepCarryOnly, ZnxNormalizeFirstStepInplace,
ZnxNormalizeMiddleStep, ZnxNormalizeMiddleStepCarryOnly, ZnxNormalizeMiddleStepInplace, ZnxZero,
},
source::Source,
};
pub fn vec_znx_normalize_tmp_bytes(n: usize) -> usize {
n * size_of::<i64>()
2 * n * size_of::<i64>()
}
pub fn vec_znx_normalize<R, A, ZNXARI>(basek: usize, res: &mut R, res_col: usize, a: &A, a_col: usize, carry: &mut [i64])
where
pub fn vec_znx_normalize<R, A, ZNXARI>(
res_base2k: usize,
res: &mut R,
res_col: usize,
a_base2k: usize,
a: &A,
a_col: usize,
carry: &mut [i64],
) where
R: VecZnxToMut,
A: VecZnxToRef,
ZNXARI: ZnxZero
+ ZnxCopy
+ ZnxAddInplace
+ ZnxMulPowerOfTwoInplace
+ ZnxNormalizeFirstStepCarryOnly
+ ZnxNormalizeMiddleStepCarryOnly
+ ZnxNormalizeMiddleStep
+ ZnxNormalizeFinalStep
+ ZnxNormalizeFirstStep,
+ ZnxNormalizeFirstStep
+ ZnxExtractDigitAddMul
+ ZnxNormalizeDigit,
{
let mut res: VecZnx<&mut [u8]> = res.to_mut();
let a: VecZnx<&[u8]> = a.to_ref();
#[cfg(debug_assertions)]
{
assert!(carry.len() >= res.n());
assert!(carry.len() >= 2 * res.n());
assert_eq!(res.n(), a.n());
}
let n: usize = res.n();
let res_size: usize = res.size();
let a_size = a.size();
let a_size: usize = a.size();
if a_size > res_size {
for j in (res_size..a_size).rev() {
if j == a_size - 1 {
ZNXARI::znx_normalize_first_step_carry_only(basek, 0, a.at(a_col, j), carry);
} else {
ZNXARI::znx_normalize_middle_step_carry_only(basek, 0, a.at(a_col, j), carry);
if res_base2k == a_base2k {
if a_size > res_size {
for j in (res_size..a_size).rev() {
if j == a_size - 1 {
ZNXARI::znx_normalize_first_step_carry_only(res_base2k, 0, a.at(a_col, j), carry);
} else {
ZNXARI::znx_normalize_middle_step_carry_only(res_base2k, 0, a.at(a_col, j), carry);
}
}
for j in (1..res_size).rev() {
ZNXARI::znx_normalize_middle_step(res_base2k, 0, res.at_mut(res_col, j), a.at(a_col, j), carry);
}
ZNXARI::znx_normalize_final_step(res_base2k, 0, res.at_mut(res_col, 0), a.at(a_col, 0), carry);
} else {
for j in (0..a_size).rev() {
if j == a_size - 1 {
ZNXARI::znx_normalize_first_step(res_base2k, 0, res.at_mut(res_col, j), a.at(a_col, j), carry);
} else if j == 0 {
ZNXARI::znx_normalize_final_step(res_base2k, 0, res.at_mut(res_col, j), a.at(a_col, j), carry);
} else {
ZNXARI::znx_normalize_middle_step(res_base2k, 0, res.at_mut(res_col, j), a.at(a_col, j), carry);
}
}
for j in a_size..res_size {
ZNXARI::znx_zero(res.at_mut(res_col, j));
}
}
for j in (1..res_size).rev() {
ZNXARI::znx_normalize_middle_step(basek, 0, res.at_mut(res_col, j), a.at(a_col, j), carry);
}
ZNXARI::znx_normalize_final_step(basek, 0, res.at_mut(res_col, 0), a.at(a_col, 0), carry);
} else {
for j in (0..a_size).rev() {
let (a_norm, carry) = carry.split_at_mut(n);
// Relevant limbs of res
let res_min_size: usize = (a_size * a_base2k).div_ceil(res_base2k).min(res_size);
// Relevant limbs of a
let a_min_size: usize = (res_size * res_base2k).div_ceil(a_base2k).min(a_size);
// Get carry for limbs of a that have higher precision than res
for j in (a_min_size..a_size).rev() {
if j == a_size - 1 {
ZNXARI::znx_normalize_first_step(basek, 0, res.at_mut(res_col, j), a.at(a_col, j), carry);
} else if j == 0 {
ZNXARI::znx_normalize_final_step(basek, 0, res.at_mut(res_col, j), a.at(a_col, j), carry);
ZNXARI::znx_normalize_first_step_carry_only(res_base2k, 0, a.at(a_col, j), carry);
} else {
ZNXARI::znx_normalize_middle_step(basek, 0, res.at_mut(res_col, j), a.at(a_col, j), carry);
ZNXARI::znx_normalize_middle_step_carry_only(res_base2k, 0, a.at(a_col, j), carry);
}
}
for j in a_size..res_size {
if a_min_size == a_size {
ZNXARI::znx_zero(carry);
}
// Maximum relevant precision of a
let a_prec: usize = a_min_size * a_base2k;
// Maximum relevant precision of res
let res_prec: usize = res_min_size * res_base2k;
// Res limb index
let mut res_idx: usize = res_min_size - 1;
// Tracker: how much of res is left to be populated
// for the current limb.
let mut res_left: usize = res_base2k;
for j in (0..a_min_size).rev() {
// Tracker: how much of a_norm is left to
// be flushed onto res.
let mut a_left: usize = a_base2k;
// Normalizes the j-th limb of a and stores the result into a_norm.
// This step is required to avoid overflow in the next step,
// which assumes that |a| is bounded by 2^{a_base2k -1}.
if j != 0 {
ZNXARI::znx_normalize_middle_step(a_base2k, 0, a_norm, a.at(a_col, j), carry);
} else {
ZNXARI::znx_normalize_final_step(a_base2k, 0, a_norm, a.at(a_col, j), carry);
}
// In the first iteration we need to match the precision of the input/output.
// If a_min_size * a_base2k > res_min_size * res_base2k,
// then divround a_norm by the difference in precision and
// act as if a_norm had already been partially consumed.
// Otherwise act as if res had already been populated
// by the difference.
if j == a_min_size - 1 {
if a_prec > res_prec {
ZNXARI::znx_mul_power_of_two_inplace(res_prec as i64 - a_prec as i64, a_norm);
a_left -= a_prec - res_prec;
} else if res_prec > a_prec {
res_left -= res_prec - a_prec;
}
}
// Flushes a into res
loop {
// Selects the maximum amount of a that can be flushed
let a_take: usize = a_base2k.min(a_left).min(res_left);
// Output limb
let res_slice: &mut [i64] = res.at_mut(res_col, res_idx);
// Scaling of the value to flush
let lsh: usize = res_base2k - res_left;
// Extracts the bits to flush into the output and updates
// a_norm accordingly.
ZNXARI::znx_extract_digit_addmul(a_take, lsh, res_slice, a_norm);
// Updates the trackers
a_left -= a_take;
res_left -= a_take;
// If the current limb of res is full,
// then normalizes this limb and adds
// the carry to a_norm.
if res_left == 0 {
// Updates tracker
res_left += res_base2k;
// Normalizes res and propagates the carry to a.
ZNXARI::znx_normalize_digit(res_base2k, res_slice, a_norm);
// If we reached the last limb of res, breaks;
// we might rerun the above loop if the
// base2k of a is much smaller than the base2k
// of res.
if res_idx == 0 {
ZNXARI::znx_add_inplace(carry, a_norm);
break;
}
// Else updates the limb index of res.
res_idx -= 1
}
// If a_norm is exhausted, breaks the loop.
if a_left == 0 {
ZNXARI::znx_add_inplace(carry, a_norm);
break;
}
}
}
for j in res_min_size..res_size {
ZNXARI::znx_zero(res.at_mut(res_col, j));
}
}
}
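The else branch above re-groups the bit-stream of every coefficient from a_base2k-bit digits into res_base2k-bit digits, flushing at most a_take bits at a time and renormalizing whenever an output limb fills up. A self-contained sketch of that re-grouping on a single unsigned value, deliberately ignoring the signed (balanced) digits and the per-coefficient carry propagation the real code performs:

// Re-chunks the digits of a non-negative value from base 2^a_k into base 2^res_k,
// most-significant digit first. Left-aligns the result; the real code instead
// adjusts the first processed limb with znx_mul_power_of_two_inplace when the
// two precisions differ.
fn rechunk_digits(a_k: u32, res_k: u32, digits: &[u64]) -> Vec<u64> {
    let total_bits = a_k as usize * digits.len();
    debug_assert!(total_bits <= 128, "sketch only handles up to 128 bits");
    // Reassemble the value from its input digits.
    let mut v: u128 = 0;
    for &d in digits {
        v = (v << a_k) | u128::from(d);
    }
    // Number of output digits needed to keep the same precision.
    let out_len = total_bits.div_ceil(res_k as usize);
    // Pad on the right so the most-significant bits keep their alignment.
    v <<= out_len * res_k as usize - total_bits;
    // Slice the value back into res_k-bit digits.
    let mut out = vec![0u64; out_len];
    for i in (0..out_len).rev() {
        out[i] = (v & ((1u128 << res_k) - 1)) as u64;
        v >>= res_k;
    }
    out
}

// Example: 18 bits held as three base-2^6 digits become two base-2^10 digits
// (right-padded by 2 bits): rechunk_digits(6, 10, &[3, 17, 42]) == vec![52, 424].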
pub fn vec_znx_normalize_inplace<R: VecZnxToMut, ZNXARI>(basek: usize, res: &mut R, res_col: usize, carry: &mut [i64])
pub fn vec_znx_normalize_inplace<R: VecZnxToMut, ZNXARI>(base2k: usize, res: &mut R, res_col: usize, carry: &mut [i64])
where
ZNXARI: ZnxNormalizeFirstStepInplace + ZnxNormalizeMiddleStepInplace + ZnxNormalizeFinalStepInplace,
{
@@ -85,11 +218,11 @@ where
for j in (0..res_size).rev() {
if j == res_size - 1 {
ZNXARI::znx_normalize_first_step_inplace(basek, 0, res.at_mut(res_col, j), carry);
ZNXARI::znx_normalize_first_step_inplace(base2k, 0, res.at_mut(res_col, j), carry);
} else if j == 0 {
ZNXARI::znx_normalize_final_step_inplace(basek, 0, res.at_mut(res_col, j), carry);
ZNXARI::znx_normalize_final_step_inplace(base2k, 0, res.at_mut(res_col, j), carry);
} else {
ZNXARI::znx_normalize_middle_step_inplace(basek, 0, res.at_mut(res_col, j), carry);
ZNXARI::znx_normalize_middle_step_inplace(base2k, 0, res.at_mut(res_col, j), carry);
}
}
}
@@ -99,7 +232,7 @@ where
Module<B>: VecZnxNormalize<B> + ModuleNew<B> + VecZnxNormalizeTmpBytes,
ScratchOwned<B>: ScratchOwnedAlloc<B> + ScratchOwnedBorrow<B>,
{
let group_name: String = format!("vec_znx_normalize::{}", label);
let group_name: String = format!("vec_znx_normalize::{label}");
let mut group = c.benchmark_group(group_name);
@@ -114,7 +247,7 @@ where
let module: Module<B> = Module::<B>::new(n as u64);
let basek: usize = 50;
let base2k: usize = 50;
let mut source: Source = Source::new([0u8; 32]);
@@ -129,7 +262,7 @@ where
move || {
for i in 0..cols {
module.vec_znx_normalize(basek, &mut res, i, &a, i, scratch.borrow());
module.vec_znx_normalize(base2k, &mut res, i, base2k, &a, i, scratch.borrow());
}
black_box(());
}
@@ -149,7 +282,7 @@ where
Module<B>: VecZnxNormalizeInplace<B> + ModuleNew<B> + VecZnxNormalizeTmpBytes,
ScratchOwned<B>: ScratchOwnedAlloc<B> + ScratchOwnedBorrow<B>,
{
let group_name: String = format!("vec_znx_normalize_inplace::{}", label);
let group_name: String = format!("vec_znx_normalize_inplace::{label}");
let mut group = c.benchmark_group(group_name);
@@ -164,7 +297,7 @@ where
let module: Module<B> = Module::<B>::new(n as u64);
let basek: usize = 50;
let base2k: usize = 50;
let mut source: Source = Source::new([0u8; 32]);
@@ -177,7 +310,7 @@ where
move || {
for i in 0..cols {
module.vec_znx_normalize_inplace(basek, &mut a, i, scratch.borrow());
module.vec_znx_normalize_inplace(base2k, &mut a, i, scratch.borrow());
}
black_box(());
}
@@ -191,3 +324,83 @@ where
group.finish();
}
#[test]
fn test_vec_znx_normalize_conv() {
let n: usize = 8;
let mut carry: Vec<i64> = vec![0i64; 2 * n];
use crate::reference::znx::ZnxRef;
use rug::ops::SubAssignRound;
use rug::{Float, float::Round};
let mut source: Source = Source::new([1u8; 32]);
let prec: usize = 128;
let mut data: Vec<i128> = vec![0i128; n];
data.iter_mut().for_each(|x| *x = source.next_i128());
for start_base2k in 1..50 {
for end_base2k in 1..50 {
let end_size: usize = prec.div_ceil(end_base2k);
let mut want: VecZnx<Vec<u8>> = VecZnx::alloc(n, 1, end_size);
want.encode_vec_i128(end_base2k, 0, prec, &data);
vec_znx_normalize_inplace::<_, ZnxRef>(end_base2k, &mut want, 0, &mut carry);
// Creates a temporary poly where encoding is in start_base2k
let mut tmp: VecZnx<Vec<u8>> = VecZnx::alloc(n, 1, prec.div_ceil(start_base2k));
tmp.encode_vec_i128(start_base2k, 0, prec, &data);
vec_znx_normalize_inplace::<_, ZnxRef>(start_base2k, &mut tmp, 0, &mut carry);
let mut data_tmp: Vec<Float> = (0..n).map(|_| Float::with_val(prec as u32, 0)).collect();
tmp.decode_vec_float(start_base2k, 0, &mut data_tmp);
let mut have: VecZnx<Vec<u8>> = VecZnx::alloc(n, 1, end_size);
vec_znx_normalize::<_, _, ZnxRef>(end_base2k, &mut have, 0, start_base2k, &tmp, 0, &mut carry);
let out_prec: u32 = (end_size * end_base2k) as u32;
let mut data_want: Vec<Float> = (0..n)
.map(|_| Float::with_val(out_prec as u32, 0))
.collect();
let mut data_res: Vec<Float> = (0..n)
.map(|_| Float::with_val(out_prec as u32, 0))
.collect();
have.decode_vec_float(end_base2k, 0, &mut data_want);
want.decode_vec_float(end_base2k, 0, &mut data_res);
for i in 0..n {
let mut err: Float = data_want[i].clone();
err.sub_assign_round(&data_res[i], Round::Nearest);
err = err.abs();
// println!(
// "want: {} have: {} tmp: {} (want-have): {}",
// data_want[i].to_f64(),
// data_res[i].to_f64(),
// data_tmp[i].to_f64(),
// err.to_f64()
// );
let err_log2: f64 = err
.clone()
.max(&Float::with_val(prec as u32, 1e-60))
.log2()
.to_f64();
assert!(
err_log2 <= -(out_prec as f64) + 1.,
"{} {}",
err_log2,
-(out_prec as f64) + 1.
)
}
}
}
}
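The final assertion encodes the accuracy contract of the conversion: for every coefficient, the value decoded from the cross-base result must agree with the directly re-encoded reference up to the output precision,

    |have[i] - want[i]| <= 2^(1 - out_prec),   with out_prec = end_size * end_base2k,

which is exactly the err_log2 <= -(out_prec as f64) + 1. check in the code, exercised here for every pair of bases between 1 and 49 bits.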

View File

@@ -61,7 +61,7 @@ pub fn bench_vec_znx_rotate<B: Backend>(c: &mut Criterion, label: &str)
where
Module<B>: VecZnxRotate + ModuleNew<B>,
{
let group_name: String = format!("vec_znx_rotate::{}", label);
let group_name: String = format!("vec_znx_rotate::{label}");
let mut group = c.benchmark_group(group_name);
@@ -106,7 +106,7 @@ where
Module<B>: VecZnxRotateInplace<B> + VecZnxRotateInplaceTmpBytes + ModuleNew<B>,
ScratchOwned<B>: ScratchOwnedAlloc<B> + ScratchOwnedBorrow<B>,
{
let group_name: String = format!("vec_znx_rotate_inplace::{}", label);
let group_name: String = format!("vec_znx_rotate_inplace::{label}");
let mut group = c.benchmark_group(group_name);

View File

@@ -4,18 +4,18 @@ use crate::{
source::Source,
};
pub fn vec_znx_fill_uniform_ref<R>(basek: usize, res: &mut R, res_col: usize, source: &mut Source)
pub fn vec_znx_fill_uniform_ref<R>(base2k: usize, res: &mut R, res_col: usize, source: &mut Source)
where
R: VecZnxToMut,
{
let mut res: VecZnx<&mut [u8]> = res.to_mut();
for j in 0..res.size() {
znx_fill_uniform_ref(basek, res.at_mut(res_col, j), source)
znx_fill_uniform_ref(base2k, res.at_mut(res_col, j), source)
}
}
pub fn vec_znx_fill_normal_ref<R>(
basek: usize,
base2k: usize,
res: &mut R,
res_col: usize,
k: usize,
@@ -32,8 +32,8 @@ pub fn vec_znx_fill_normal_ref<R>(
(bound.log2().ceil() as i64)
);
let limb: usize = k.div_ceil(basek) - 1;
let scale: f64 = (1 << ((limb + 1) * basek - k)) as f64;
let limb: usize = k.div_ceil(base2k) - 1;
let scale: f64 = (1 << ((limb + 1) * base2k - k)) as f64;
znx_fill_normal_f64_ref(
res.at_mut(res_col, limb),
sigma * scale,
@@ -42,8 +42,15 @@ pub fn vec_znx_fill_normal_ref<R>(
)
}
pub fn vec_znx_add_normal_ref<R>(basek: usize, res: &mut R, res_col: usize, k: usize, sigma: f64, bound: f64, source: &mut Source)
where
pub fn vec_znx_add_normal_ref<R>(
base2k: usize,
res: &mut R,
res_col: usize,
k: usize,
sigma: f64,
bound: f64,
source: &mut Source,
) where
R: VecZnxToMut,
{
let mut res: VecZnx<&mut [u8]> = res.to_mut();
@@ -53,8 +60,8 @@ where
(bound.log2().ceil() as i64)
);
let limb: usize = k.div_ceil(basek) - 1;
let scale: f64 = (1 << ((limb + 1) * basek - k)) as f64;
let limb: usize = k.div_ceil(base2k) - 1;
let scale: f64 = (1 << ((limb + 1) * base2k - k)) as f64;
znx_add_normal_f64_ref(
res.at_mut(res_col, limb),
sigma * scale,

View File

@@ -20,7 +20,7 @@ pub fn vec_znx_lsh_tmp_bytes(n: usize) -> usize {
n * size_of::<i64>()
}
pub fn vec_znx_lsh_inplace<R, ZNXARI>(basek: usize, k: usize, res: &mut R, res_col: usize, carry: &mut [i64])
pub fn vec_znx_lsh_inplace<R, ZNXARI>(base2k: usize, k: usize, res: &mut R, res_col: usize, carry: &mut [i64])
where
R: VecZnxToMut,
ZNXARI: ZnxZero
@@ -35,8 +35,8 @@ where
let n: usize = res.n();
let cols: usize = res.cols();
let size: usize = res.size();
let steps: usize = k / basek;
let k_rem: usize = k % basek;
let steps: usize = k / base2k;
let k_rem: usize = k % base2k;
if steps >= size {
for j in 0..size {
@@ -45,7 +45,7 @@ where
return;
}
// Inplace shift of limbs by a k/basek
// Inplace shift of limbs by a k/base2k
if steps > 0 {
let start: usize = n * res_col;
let end: usize = start + n;
@@ -65,21 +65,21 @@ where
}
}
// Inplace normalization with left shift of k % basek
if !k.is_multiple_of(basek) {
// Inplace normalization with left shift of k % base2k
if !k.is_multiple_of(base2k) {
for j in (0..size - steps).rev() {
if j == size - steps - 1 {
ZNXARI::znx_normalize_first_step_inplace(basek, k_rem, res.at_mut(res_col, j), carry);
ZNXARI::znx_normalize_first_step_inplace(base2k, k_rem, res.at_mut(res_col, j), carry);
} else if j == 0 {
ZNXARI::znx_normalize_final_step_inplace(basek, k_rem, res.at_mut(res_col, j), carry);
ZNXARI::znx_normalize_final_step_inplace(base2k, k_rem, res.at_mut(res_col, j), carry);
} else {
ZNXARI::znx_normalize_middle_step_inplace(basek, k_rem, res.at_mut(res_col, j), carry);
ZNXARI::znx_normalize_middle_step_inplace(base2k, k_rem, res.at_mut(res_col, j), carry);
}
}
}
}
pub fn vec_znx_lsh<R, A, ZNXARI>(basek: usize, k: usize, res: &mut R, res_col: usize, a: &A, a_col: usize, carry: &mut [i64])
pub fn vec_znx_lsh<R, A, ZNXARI>(base2k: usize, k: usize, res: &mut R, res_col: usize, a: &A, a_col: usize, carry: &mut [i64])
where
R: VecZnxToMut,
A: VecZnxToRef,
@@ -90,8 +90,8 @@ where
let res_size: usize = res.size();
let a_size = a.size();
let steps: usize = k / basek;
let k_rem: usize = k % basek;
let steps: usize = k / base2k;
let k_rem: usize = k % base2k;
if steps >= res_size.min(a_size) {
for j in 0..res_size {
@@ -103,12 +103,12 @@ where
let min_size: usize = a_size.min(res_size) - steps;
// Simply a left shifted normalization of limbs
// by k/basek and intra-limb by basek - k%basek
if !k.is_multiple_of(basek) {
// by k/base2k and intra-limb by base2k - k%base2k
if !k.is_multiple_of(base2k) {
for j in (0..min_size).rev() {
if j == min_size - 1 {
ZNXARI::znx_normalize_first_step(
basek,
base2k,
k_rem,
res.at_mut(res_col, j),
a.at(a_col, j + steps),
@@ -116,7 +116,7 @@ where
);
} else if j == 0 {
ZNXARI::znx_normalize_final_step(
basek,
base2k,
k_rem,
res.at_mut(res_col, j),
a.at(a_col, j + steps),
@@ -124,7 +124,7 @@ where
);
} else {
ZNXARI::znx_normalize_middle_step(
basek,
base2k,
k_rem,
res.at_mut(res_col, j),
a.at(a_col, j + steps),
@@ -133,7 +133,7 @@ where
}
}
} else {
// If k % basek = 0, then this is simply a copy.
// If k % base2k = 0, then this is simply a copy.
for j in (0..min_size).rev() {
ZNXARI::znx_copy(res.at_mut(res_col, j), a.at(a_col, j + steps));
}
@@ -149,7 +149,7 @@ pub fn vec_znx_rsh_tmp_bytes(n: usize) -> usize {
n * size_of::<i64>()
}
pub fn vec_znx_rsh_inplace<R, ZNXARI>(basek: usize, k: usize, res: &mut R, res_col: usize, carry: &mut [i64])
pub fn vec_znx_rsh_inplace<R, ZNXARI>(base2k: usize, k: usize, res: &mut R, res_col: usize, carry: &mut [i64])
where
R: VecZnxToMut,
ZNXARI: ZnxZero
@@ -166,8 +166,8 @@ where
let cols: usize = res.cols();
let size: usize = res.size();
let mut steps: usize = k / basek;
let k_rem: usize = k % basek;
let mut steps: usize = k / base2k;
let k_rem: usize = k % base2k;
if k == 0 {
return;
@@ -184,8 +184,8 @@ where
let end: usize = start + n;
let slice_size: usize = n * cols;
if !k.is_multiple_of(basek) {
// We rsh by an additional basek and then lsh by basek-k
if !k.is_multiple_of(base2k) {
// We rsh by an additional base2k and then lsh by base2k - k%base2k
// Allows re-use of the efficient normalization code,
// avoids overflows & produces output that is normalized
steps += 1;
@@ -194,9 +194,9 @@ where
// but the carry still needs to be computed.
(size - steps..size).rev().for_each(|j| {
if j == size - 1 {
ZNXARI::znx_normalize_first_step_carry_only(basek, basek - k_rem, res.at(res_col, j), carry);
ZNXARI::znx_normalize_first_step_carry_only(base2k, base2k - k_rem, res.at(res_col, j), carry);
} else {
ZNXARI::znx_normalize_middle_step_carry_only(basek, basek - k_rem, res.at(res_col, j), carry);
ZNXARI::znx_normalize_middle_step_carry_only(base2k, base2k - k_rem, res.at(res_col, j), carry);
}
});
@@ -206,20 +206,20 @@ where
let (lhs, rhs) = res_raw.split_at_mut(slice_size * j);
let rhs_slice: &mut [i64] = &mut rhs[start..end];
let lhs_slice: &[i64] = &lhs[(j - steps) * slice_size + start..(j - steps) * slice_size + end];
ZNXARI::znx_normalize_middle_step(basek, basek - k_rem, rhs_slice, lhs_slice, carry);
ZNXARI::znx_normalize_middle_step(base2k, base2k - k_rem, rhs_slice, lhs_slice, carry);
});
// Propagates carry on the rest of the limbs of res
for j in (0..steps).rev() {
ZNXARI::znx_zero(res.at_mut(res_col, j));
if j == 0 {
ZNXARI::znx_normalize_final_step_inplace(basek, basek - k_rem, res.at_mut(res_col, j), carry);
ZNXARI::znx_normalize_final_step_inplace(base2k, base2k - k_rem, res.at_mut(res_col, j), carry);
} else {
ZNXARI::znx_normalize_middle_step_inplace(basek, basek - k_rem, res.at_mut(res_col, j), carry);
ZNXARI::znx_normalize_middle_step_inplace(base2k, base2k - k_rem, res.at_mut(res_col, j), carry);
}
}
} else {
// Shift by multiples of basek
// Shift by multiples of base2k
let res_raw: &mut [i64] = res.raw_mut();
(steps..size).rev().for_each(|j| {
let (lhs, rhs) = res_raw.split_at_mut(slice_size * j);
@@ -236,7 +236,7 @@ where
}
}
pub fn vec_znx_rsh<R, A, ZNXARI>(basek: usize, k: usize, res: &mut R, res_col: usize, a: &A, a_col: usize, carry: &mut [i64])
pub fn vec_znx_rsh<R, A, ZNXARI>(base2k: usize, k: usize, res: &mut R, res_col: usize, a: &A, a_col: usize, carry: &mut [i64])
where
R: VecZnxToMut,
A: VecZnxToRef,
@@ -256,8 +256,8 @@ where
let res_size: usize = res.size();
let a_size: usize = a.size();
let mut steps: usize = k / basek;
let k_rem: usize = k % basek;
let mut steps: usize = k / base2k;
let k_rem: usize = k % base2k;
if k == 0 {
vec_znx_copy::<_, _, ZNXARI>(&mut res, res_col, &a, a_col);
@@ -271,8 +271,8 @@ where
return;
}
if !k.is_multiple_of(basek) {
// We rsh by an additional basek and then lsh by basek-k
if !k.is_multiple_of(base2k) {
// We rsh by an additional base2k and then lsh by base2k - k%base2k
// Allows re-use of the efficient normalization code,
// avoids overflows & produces output that is normalized
steps += 1;
@@ -281,9 +281,9 @@ where
// but the carry still needs to be computed.
for j in (res_size..a_size + steps).rev() {
if j == a_size + steps - 1 {
ZNXARI::znx_normalize_first_step_carry_only(basek, basek - k_rem, a.at(a_col, j - steps), carry);
ZNXARI::znx_normalize_first_step_carry_only(base2k, base2k - k_rem, a.at(a_col, j - steps), carry);
} else {
ZNXARI::znx_normalize_middle_step_carry_only(basek, basek - k_rem, a.at(a_col, j - steps), carry);
ZNXARI::znx_normalize_middle_step_carry_only(base2k, base2k - k_rem, a.at(a_col, j - steps), carry);
}
}
@@ -300,16 +300,16 @@ where
// Case if no limb of a was previously discarded
if res_size.saturating_sub(steps) >= a_size && j == min_size - 1 {
ZNXARI::znx_normalize_first_step(
basek,
basek - k_rem,
base2k,
base2k - k_rem,
res.at_mut(res_col, j),
a.at(a_col, j - steps),
carry,
);
} else {
ZNXARI::znx_normalize_middle_step(
basek,
basek - k_rem,
base2k,
base2k - k_rem,
res.at_mut(res_col, j),
a.at(a_col, j - steps),
carry,
@@ -321,9 +321,9 @@ where
for j in (0..steps).rev() {
ZNXARI::znx_zero(res.at_mut(res_col, j));
if j == 0 {
ZNXARI::znx_normalize_final_step_inplace(basek, basek - k_rem, res.at_mut(res_col, j), carry);
ZNXARI::znx_normalize_final_step_inplace(base2k, base2k - k_rem, res.at_mut(res_col, j), carry);
} else {
ZNXARI::znx_normalize_middle_step_inplace(basek, basek - k_rem, res.at_mut(res_col, j), carry);
ZNXARI::znx_normalize_middle_step_inplace(base2k, base2k - k_rem, res.at_mut(res_col, j), carry);
}
}
} else {
@@ -351,7 +351,7 @@ where
Module<B>: ModuleNew<B> + VecZnxLshInplace<B>,
ScratchOwned<B>: ScratchOwnedAlloc<B> + ScratchOwnedBorrow<B>,
{
let group_name: String = format!("vec_znx_lsh_inplace::{}", label);
let group_name: String = format!("vec_znx_lsh_inplace::{label}");
let mut group = c.benchmark_group(group_name);
@@ -366,7 +366,7 @@ where
let module: Module<B> = Module::<B>::new(n as u64);
let basek: usize = 50;
let base2k: usize = 50;
let mut source: Source = Source::new([0u8; 32]);
@@ -381,7 +381,7 @@ where
move || {
for i in 0..cols {
module.vec_znx_lsh_inplace(basek, basek - 1, &mut b, i, scratch.borrow());
module.vec_znx_lsh_inplace(base2k, base2k - 1, &mut b, i, scratch.borrow());
}
black_box(());
}
@@ -401,7 +401,7 @@ where
Module<B>: VecZnxLsh<B> + ModuleNew<B>,
ScratchOwned<B>: ScratchOwnedAlloc<B> + ScratchOwnedBorrow<B>,
{
let group_name: String = format!("vec_znx_lsh::{}", label);
let group_name: String = format!("vec_znx_lsh::{label}");
let mut group = c.benchmark_group(group_name);
@@ -416,7 +416,7 @@ where
let module: Module<B> = Module::<B>::new(n as u64);
let basek: usize = 50;
let base2k: usize = 50;
let mut source: Source = Source::new([0u8; 32]);
@@ -431,7 +431,7 @@ where
move || {
for i in 0..cols {
module.vec_znx_lsh(basek, basek - 1, &mut res, i, &a, i, scratch.borrow());
module.vec_znx_lsh(base2k, base2k - 1, &mut res, i, &a, i, scratch.borrow());
}
black_box(());
}
@@ -451,7 +451,7 @@ where
Module<B>: VecZnxRshInplace<B> + ModuleNew<B>,
ScratchOwned<B>: ScratchOwnedAlloc<B> + ScratchOwnedBorrow<B>,
{
let group_name: String = format!("vec_znx_rsh_inplace::{}", label);
let group_name: String = format!("vec_znx_rsh_inplace::{label}");
let mut group = c.benchmark_group(group_name);
@@ -466,7 +466,7 @@ where
let module: Module<B> = Module::<B>::new(n as u64);
let basek: usize = 50;
let base2k: usize = 50;
let mut source: Source = Source::new([0u8; 32]);
@@ -481,7 +481,7 @@ where
move || {
for i in 0..cols {
module.vec_znx_rsh_inplace(basek, basek - 1, &mut b, i, scratch.borrow());
module.vec_znx_rsh_inplace(base2k, base2k - 1, &mut b, i, scratch.borrow());
}
black_box(());
}
@@ -501,7 +501,7 @@ where
Module<B>: VecZnxRsh<B> + ModuleNew<B>,
ScratchOwned<B>: ScratchOwnedAlloc<B> + ScratchOwnedBorrow<B>,
{
let group_name: String = format!("vec_znx_rsh::{}", label);
let group_name: String = format!("vec_znx_rsh::{label}");
let mut group = c.benchmark_group(group_name);
@@ -516,7 +516,7 @@ where
let module: Module<B> = Module::<B>::new(n as u64);
let basek: usize = 50;
let base2k: usize = 50;
let mut source: Source = Source::new([0u8; 32]);
@@ -531,7 +531,7 @@ where
move || {
for i in 0..cols {
module.vec_znx_rsh(basek, basek - 1, &mut res, i, &a, i, scratch.borrow());
module.vec_znx_rsh(base2k, base2k - 1, &mut res, i, &a, i, scratch.borrow());
}
black_box(());
}
@@ -553,7 +553,7 @@ mod tests {
reference::{
vec_znx::{
vec_znx_copy, vec_znx_lsh, vec_znx_lsh_inplace, vec_znx_normalize_inplace, vec_znx_rsh, vec_znx_rsh_inplace,
vec_znx_sub_ab_inplace,
vec_znx_sub_inplace,
},
znx::ZnxRef,
},
@@ -574,20 +574,20 @@ mod tests {
let mut carry: Vec<i64> = vec![0i64; n];
let basek: usize = 50;
let base2k: usize = 50;
for k in 0..256 {
a.fill_uniform(50, &mut source);
for i in 0..cols {
vec_znx_normalize_inplace::<_, ZnxRef>(basek, &mut a, i, &mut carry);
vec_znx_normalize_inplace::<_, ZnxRef>(base2k, &mut a, i, &mut carry);
vec_znx_copy::<_, _, ZnxRef>(&mut res_ref, i, &a, i);
}
for i in 0..cols {
vec_znx_lsh_inplace::<_, ZnxRef>(basek, k, &mut res_ref, i, &mut carry);
vec_znx_lsh::<_, _, ZnxRef>(basek, k, &mut res_test, i, &a, i, &mut carry);
vec_znx_normalize_inplace::<_, ZnxRef>(basek, &mut res_test, i, &mut carry);
vec_znx_lsh_inplace::<_, ZnxRef>(base2k, k, &mut res_ref, i, &mut carry);
vec_znx_lsh::<_, _, ZnxRef>(base2k, k, &mut res_test, i, &a, i, &mut carry);
vec_znx_normalize_inplace::<_, ZnxRef>(base2k, &mut res_test, i, &mut carry);
}
assert_eq!(res_ref, res_test);
@@ -606,7 +606,7 @@ mod tests {
let mut carry: Vec<i64> = vec![0i64; n];
let basek: usize = 50;
let base2k: usize = 50;
let mut source: Source = Source::new([0u8; 32]);
@@ -615,29 +615,29 @@ mod tests {
for a_size in [res_size - 1, res_size, res_size + 1] {
let mut a: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, a_size);
for k in 0..res_size * basek {
for k in 0..res_size * base2k {
a.fill_uniform(50, &mut source);
for i in 0..cols {
vec_znx_normalize_inplace::<_, ZnxRef>(basek, &mut a, i, &mut carry);
vec_znx_normalize_inplace::<_, ZnxRef>(base2k, &mut a, i, &mut carry);
vec_znx_copy::<_, _, ZnxRef>(&mut res_ref, i, &a, i);
}
res_test.fill_uniform(50, &mut source);
for j in 0..cols {
vec_znx_rsh_inplace::<_, ZnxRef>(basek, k, &mut res_ref, j, &mut carry);
vec_znx_rsh::<_, _, ZnxRef>(basek, k, &mut res_test, j, &a, j, &mut carry);
vec_znx_rsh_inplace::<_, ZnxRef>(base2k, k, &mut res_ref, j, &mut carry);
vec_znx_rsh::<_, _, ZnxRef>(base2k, k, &mut res_test, j, &a, j, &mut carry);
}
for j in 0..cols {
vec_znx_lsh_inplace::<_, ZnxRef>(basek, k, &mut res_ref, j, &mut carry);
vec_znx_lsh_inplace::<_, ZnxRef>(basek, k, &mut res_test, j, &mut carry);
vec_znx_lsh_inplace::<_, ZnxRef>(base2k, k, &mut res_ref, j, &mut carry);
vec_znx_lsh_inplace::<_, ZnxRef>(base2k, k, &mut res_test, j, &mut carry);
}
// Case where res has enough to fully store a right shifted without any loss
// In this case we can check exact equality.
if a_size + k.div_ceil(basek) <= res_size {
if a_size + k.div_ceil(base2k) <= res_size {
assert_eq!(res_ref, res_test);
for i in 0..cols {
@@ -656,14 +656,14 @@ mod tests {
// res.
} else {
for j in 0..cols {
vec_znx_sub_ab_inplace::<_, _, ZnxRef>(&mut res_ref, j, &a, j);
vec_znx_sub_ab_inplace::<_, _, ZnxRef>(&mut res_test, j, &a, j);
vec_znx_sub_inplace::<_, _, ZnxRef>(&mut res_ref, j, &a, j);
vec_znx_sub_inplace::<_, _, ZnxRef>(&mut res_test, j, &a, j);
vec_znx_normalize_inplace::<_, ZnxRef>(basek, &mut res_ref, j, &mut carry);
vec_znx_normalize_inplace::<_, ZnxRef>(basek, &mut res_test, j, &mut carry);
vec_znx_normalize_inplace::<_, ZnxRef>(base2k, &mut res_ref, j, &mut carry);
vec_znx_normalize_inplace::<_, ZnxRef>(base2k, &mut res_test, j, &mut carry);
assert!(res_ref.std(basek, j).log2() - (k as f64) <= (k * basek) as f64);
assert!(res_test.std(basek, j).log2() - (k as f64) <= (k * basek) as f64);
assert!(res_ref.std(base2k, j).log2() - (k as f64) <= (k * base2k) as f64);
assert!(res_test.std(base2k, j).log2() - (k as f64) <= (k * base2k) as f64);
}
}
}

View File

@@ -3,10 +3,10 @@ use std::hint::black_box;
use criterion::{BenchmarkId, Criterion};
use crate::{
api::{ModuleNew, VecZnxSub, VecZnxSubABInplace, VecZnxSubBAInplace},
api::{ModuleNew, VecZnxSub, VecZnxSubInplace, VecZnxSubNegateInplace},
layouts::{Backend, FillUniform, Module, VecZnx, VecZnxToMut, VecZnxToRef, ZnxInfos, ZnxView, ZnxViewMut},
oep::{ModuleNewImpl, VecZnxSubABInplaceImpl, VecZnxSubBAInplaceImpl, VecZnxSubImpl},
reference::znx::{ZnxCopy, ZnxNegate, ZnxNegateInplace, ZnxSub, ZnxSubABInplace, ZnxSubBAInplace, ZnxZero},
oep::{ModuleNewImpl, VecZnxSubImpl, VecZnxSubInplaceImpl, VecZnxSubNegateInplaceImpl},
reference::znx::{ZnxCopy, ZnxNegate, ZnxNegateInplace, ZnxSub, ZnxSubInplace, ZnxSubNegateInplace, ZnxZero},
source::Source,
};
@@ -64,11 +64,11 @@ where
}
}
pub fn vec_znx_sub_ab_inplace<R, A, ZNXARI>(res: &mut R, res_col: usize, a: &A, a_col: usize)
pub fn vec_znx_sub_inplace<R, A, ZNXARI>(res: &mut R, res_col: usize, a: &A, a_col: usize)
where
R: VecZnxToMut,
A: VecZnxToRef,
ZNXARI: ZnxSubABInplace,
ZNXARI: ZnxSubInplace,
{
let a: VecZnx<&[u8]> = a.to_ref();
let mut res: VecZnx<&mut [u8]> = res.to_mut();
@@ -84,15 +84,15 @@ where
let sum_size: usize = a_size.min(res_size);
for j in 0..sum_size {
ZNXARI::znx_sub_ab_inplace(res.at_mut(res_col, j), a.at(a_col, j));
ZNXARI::znx_sub_inplace(res.at_mut(res_col, j), a.at(a_col, j));
}
}
pub fn vec_znx_sub_ba_inplace<R, A, ZNXARI>(res: &mut R, res_col: usize, a: &A, a_col: usize)
pub fn vec_znx_sub_negate_inplace<R, A, ZNXARI>(res: &mut R, res_col: usize, a: &A, a_col: usize)
where
R: VecZnxToMut,
A: VecZnxToRef,
ZNXARI: ZnxSubBAInplace + ZnxNegateInplace,
ZNXARI: ZnxSubNegateInplace + ZnxNegateInplace,
{
let a: VecZnx<&[u8]> = a.to_ref();
let mut res: VecZnx<&mut [u8]> = res.to_mut();
@@ -108,7 +108,7 @@ where
let sum_size: usize = a_size.min(res_size);
for j in 0..sum_size {
ZNXARI::znx_sub_ba_inplace(res.at_mut(res_col, j), a.at(a_col, j));
ZNXARI::znx_sub_negate_inplace(res.at_mut(res_col, j), a.at(a_col, j));
}
for j in sum_size..res_size {
@@ -120,7 +120,7 @@ pub fn bench_vec_znx_sub<B>(c: &mut Criterion, label: &str)
where
B: Backend + ModuleNewImpl<B> + VecZnxSubImpl<B>,
{
let group_name: String = format!("vec_znx_sub::{}", label);
let group_name: String = format!("vec_znx_sub::{label}");
let mut group = c.benchmark_group(group_name);
@@ -161,17 +161,17 @@ where
group.finish();
}
pub fn bench_vec_znx_sub_ab_inplace<B>(c: &mut Criterion, label: &str)
pub fn bench_vec_znx_sub_inplace<B>(c: &mut Criterion, label: &str)
where
B: Backend + ModuleNewImpl<B> + VecZnxSubABInplaceImpl<B>,
B: Backend + ModuleNewImpl<B> + VecZnxSubInplaceImpl<B>,
{
let group_name: String = format!("vec_znx_sub_ab_inplace::{}", label);
let group_name: String = format!("vec_znx_sub_inplace::{label}");
let mut group = c.benchmark_group(group_name);
fn runner<B: Backend>(params: [usize; 3]) -> impl FnMut()
where
Module<B>: VecZnxSubABInplace + ModuleNew<B>,
Module<B>: VecZnxSubInplace + ModuleNew<B>,
{
let n: usize = 1 << params[0];
let cols: usize = params[1];
@@ -190,7 +190,7 @@ where
move || {
for i in 0..cols {
module.vec_znx_sub_ab_inplace(&mut b, i, &a, i);
module.vec_znx_sub_inplace(&mut b, i, &a, i);
}
black_box(());
}
@@ -205,17 +205,17 @@ where
group.finish();
}
pub fn bench_vec_znx_sub_ba_inplace<B>(c: &mut Criterion, label: &str)
pub fn bench_vec_znx_sub_negate_inplace<B>(c: &mut Criterion, label: &str)
where
B: Backend + ModuleNewImpl<B> + VecZnxSubBAInplaceImpl<B>,
B: Backend + ModuleNewImpl<B> + VecZnxSubNegateInplaceImpl<B>,
{
let group_name: String = format!("vec_znx_sub_ba_inplace::{}", label);
let group_name: String = format!("vec_znx_sub_negate_inplace::{label}");
let mut group = c.benchmark_group(group_name);
fn runner<B: Backend>(params: [usize; 3]) -> impl FnMut()
where
Module<B>: VecZnxSubBAInplace + ModuleNew<B>,
Module<B>: VecZnxSubNegateInplace + ModuleNew<B>,
{
let n: usize = 1 << params[0];
let cols: usize = params[1];
@@ -234,7 +234,7 @@ where
move || {
for i in 0..cols {
module.vec_znx_sub_ba_inplace(&mut b, i, &a, i);
module.vec_znx_sub_negate_inplace(&mut b, i, &a, i);
}
black_box(());
}

View File

@@ -1,7 +1,7 @@
use crate::layouts::{ScalarZnxToRef, VecZnxToMut, VecZnxToRef};
use crate::{
layouts::{ScalarZnx, VecZnx, ZnxInfos, ZnxView, ZnxViewMut},
reference::znx::{ZnxSub, ZnxSubABInplace, ZnxZero},
reference::znx::{ZnxSub, ZnxSubInplace, ZnxZero},
};
pub fn vec_znx_sub_scalar<R, A, B, ZNXARI>(res: &mut R, res_col: usize, a: &A, a_col: usize, b: &B, b_col: usize, b_limb: usize)
@@ -19,12 +19,7 @@ where
#[cfg(debug_assertions)]
{
assert!(
b_limb < min_size,
"b_limb: {} > min_size: {}",
b_limb,
min_size
);
assert!(b_limb < min_size, "b_limb: {b_limb} > min_size: {min_size}");
}
for j in 0..min_size {
@@ -44,7 +39,7 @@ pub fn vec_znx_sub_scalar_inplace<R, A, ZNXARI>(res: &mut R, res_col: usize, res
where
R: VecZnxToMut,
A: ScalarZnxToRef,
ZNXARI: ZnxSubABInplace,
ZNXARI: ZnxSubInplace,
{
let a: ScalarZnx<&[u8]> = a.to_ref();
let mut res: VecZnx<&mut [u8]> = res.to_mut();
@@ -54,5 +49,5 @@ where
assert!(res_limb < res.size());
}
ZNXARI::znx_sub_ab_inplace(res.at_mut(res_col, res_limb), a.at(a_col, 0));
ZNXARI::znx_sub_inplace(res.at_mut(res_col, res_limb), a.at(a_col, 0));
}

View File

@@ -9,7 +9,7 @@ pub fn zn_normalize_tmp_bytes(n: usize) -> usize {
n * size_of::<i64>()
}
pub fn zn_normalize_inplace<R, ARI>(n: usize, basek: usize, res: &mut R, res_col: usize, carry: &mut [i64])
pub fn zn_normalize_inplace<R, ARI>(n: usize, base2k: usize, res: &mut R, res_col: usize, carry: &mut [i64])
where
R: ZnToMut,
ARI: ZnxNormalizeFirstStepInplace + ZnxNormalizeFinalStepInplace + ZnxNormalizeMiddleStepInplace,
@@ -27,11 +27,11 @@ where
let out = &mut res.at_mut(res_col, j)[..n];
if j == res_size - 1 {
ARI::znx_normalize_first_step_inplace(basek, 0, out, carry);
ARI::znx_normalize_first_step_inplace(base2k, 0, out, carry);
} else if j == 0 {
ARI::znx_normalize_final_step_inplace(basek, 0, out, carry);
ARI::znx_normalize_final_step_inplace(base2k, 0, out, carry);
} else {
ARI::znx_normalize_middle_step_inplace(basek, 0, out, carry);
ARI::znx_normalize_middle_step_inplace(base2k, 0, out, carry);
}
}
}
@@ -43,7 +43,7 @@ where
{
let mut source: Source = Source::new([0u8; 32]);
let cols: usize = 2;
let basek: usize = 12;
let base2k: usize = 12;
let n = 33;
@@ -63,8 +63,8 @@ where
// Reference
for i in 0..cols {
zn_normalize_inplace::<_, ZnxRef>(n, basek, &mut res_0, i, &mut carry);
module.zn_normalize_inplace(n, basek, &mut res_1, i, scratch.borrow());
zn_normalize_inplace::<_, ZnxRef>(n, base2k, &mut res_0, i, &mut carry);
module.zn_normalize_inplace(n, base2k, &mut res_1, i, scratch.borrow());
}
assert_eq!(res_0.raw(), res_1.raw());

View File

@@ -4,20 +4,20 @@ use crate::{
source::Source,
};
pub fn zn_fill_uniform<R>(n: usize, basek: usize, res: &mut R, res_col: usize, source: &mut Source)
pub fn zn_fill_uniform<R>(n: usize, base2k: usize, res: &mut R, res_col: usize, source: &mut Source)
where
R: ZnToMut,
{
let mut res: Zn<&mut [u8]> = res.to_mut();
for j in 0..res.size() {
znx_fill_uniform_ref(basek, &mut res.at_mut(res_col, j)[..n], source)
znx_fill_uniform_ref(base2k, &mut res.at_mut(res_col, j)[..n], source)
}
}
#[allow(clippy::too_many_arguments)]
pub fn zn_fill_normal<R>(
n: usize,
basek: usize,
base2k: usize,
res: &mut R,
res_col: usize,
k: usize,
@@ -34,8 +34,8 @@ pub fn zn_fill_normal<R>(
(bound.log2().ceil() as i64)
);
let limb: usize = k.div_ceil(basek) - 1;
let scale: f64 = (1 << ((limb + 1) * basek - k)) as f64;
let limb: usize = k.div_ceil(base2k) - 1;
let scale: f64 = (1 << ((limb + 1) * base2k - k)) as f64;
znx_fill_normal_f64_ref(
&mut res.at_mut(res_col, limb)[..n],
sigma * scale,
@@ -47,7 +47,7 @@ pub fn zn_fill_normal<R>(
#[allow(clippy::too_many_arguments)]
pub fn zn_add_normal<R>(
n: usize,
basek: usize,
base2k: usize,
res: &mut R,
res_col: usize,
k: usize,
@@ -64,8 +64,8 @@ pub fn zn_add_normal<R>(
(bound.log2().ceil() as i64)
);
let limb: usize = k.div_ceil(basek) - 1;
let scale: f64 = (1 << ((limb + 1) * basek - k)) as f64;
let limb: usize = k.div_ceil(base2k) - 1;
let scale: f64 = (1 << ((limb + 1) * base2k - k)) as f64;
znx_add_normal_f64_ref(
&mut res.at_mut(res_col, limb)[..n],
sigma * scale,

View File

@@ -1,8 +1,9 @@
use crate::reference::znx::{
ZnxAdd, ZnxAddInplace, ZnxAutomorphism, ZnxCopy, ZnxNegate, ZnxNegateInplace, ZnxNormalizeFinalStep,
ZnxNormalizeFinalStepInplace, ZnxNormalizeFirstStep, ZnxNormalizeFirstStepCarryOnly, ZnxNormalizeFirstStepInplace,
ZnxNormalizeMiddleStep, ZnxNormalizeMiddleStepCarryOnly, ZnxNormalizeMiddleStepInplace, ZnxSub, ZnxSubABInplace,
ZnxSubBAInplace, ZnxSwitchRing, ZnxZero,
ZnxAdd, ZnxAddInplace, ZnxAutomorphism, ZnxCopy, ZnxExtractDigitAddMul, ZnxMulAddPowerOfTwo, ZnxMulPowerOfTwo,
ZnxMulPowerOfTwoInplace, ZnxNegate, ZnxNegateInplace, ZnxNormalizeDigit, ZnxNormalizeFinalStep, ZnxNormalizeFinalStepInplace,
ZnxNormalizeFirstStep, ZnxNormalizeFirstStepCarryOnly, ZnxNormalizeFirstStepInplace, ZnxNormalizeMiddleStep,
ZnxNormalizeMiddleStepCarryOnly, ZnxNormalizeMiddleStepInplace, ZnxSub, ZnxSubInplace, ZnxSubNegateInplace, ZnxSwitchRing,
ZnxZero,
add::{znx_add_inplace_ref, znx_add_ref},
automorphism::znx_automorphism_ref,
copy::znx_copy_ref,
@@ -12,9 +13,11 @@ use crate::reference::znx::{
znx_normalize_first_step_inplace_ref, znx_normalize_first_step_ref, znx_normalize_middle_step_carry_only_ref,
znx_normalize_middle_step_inplace_ref, znx_normalize_middle_step_ref,
},
sub::{znx_sub_ab_inplace_ref, znx_sub_ba_inplace_ref, znx_sub_ref},
sub::{znx_sub_inplace_ref, znx_sub_negate_inplace_ref, znx_sub_ref},
switch_ring::znx_switch_ring_ref,
zero::znx_zero_ref,
znx_extract_digit_addmul_ref, znx_mul_add_power_of_two_ref, znx_mul_power_of_two_inplace_ref, znx_mul_power_of_two_ref,
znx_normalize_digit_ref,
};
pub struct ZnxRef {}
@@ -40,17 +43,17 @@ impl ZnxSub for ZnxRef {
}
}
impl ZnxSubABInplace for ZnxRef {
impl ZnxSubInplace for ZnxRef {
#[inline(always)]
fn znx_sub_ab_inplace(res: &mut [i64], a: &[i64]) {
znx_sub_ab_inplace_ref(res, a);
fn znx_sub_inplace(res: &mut [i64], a: &[i64]) {
znx_sub_inplace_ref(res, a);
}
}
impl ZnxSubBAInplace for ZnxRef {
impl ZnxSubNegateInplace for ZnxRef {
#[inline(always)]
fn znx_sub_ba_inplace(res: &mut [i64], a: &[i64]) {
znx_sub_ba_inplace_ref(res, a);
fn znx_sub_negate_inplace(res: &mut [i64], a: &[i64]) {
znx_sub_negate_inplace_ref(res, a);
}
}
@@ -61,6 +64,27 @@ impl ZnxAutomorphism for ZnxRef {
}
}
impl ZnxMulPowerOfTwo for ZnxRef {
#[inline(always)]
fn znx_mul_power_of_two(k: i64, res: &mut [i64], a: &[i64]) {
znx_mul_power_of_two_ref(k, res, a);
}
}
impl ZnxMulAddPowerOfTwo for ZnxRef {
#[inline(always)]
fn znx_muladd_power_of_two(k: i64, res: &mut [i64], a: &[i64]) {
znx_mul_add_power_of_two_ref(k, res, a);
}
}
impl ZnxMulPowerOfTwoInplace for ZnxRef {
#[inline(always)]
fn znx_mul_power_of_two_inplace(k: i64, res: &mut [i64]) {
znx_mul_power_of_two_inplace_ref(k, res);
}
}
impl ZnxCopy for ZnxRef {
#[inline(always)]
fn znx_copy(res: &mut [i64], a: &[i64]) {
@@ -98,56 +122,70 @@ impl ZnxSwitchRing for ZnxRef {
impl ZnxNormalizeFinalStep for ZnxRef {
#[inline(always)]
fn znx_normalize_final_step(basek: usize, lsh: usize, x: &mut [i64], a: &[i64], carry: &mut [i64]) {
znx_normalize_final_step_ref(basek, lsh, x, a, carry);
fn znx_normalize_final_step(base2k: usize, lsh: usize, x: &mut [i64], a: &[i64], carry: &mut [i64]) {
znx_normalize_final_step_ref(base2k, lsh, x, a, carry);
}
}
impl ZnxNormalizeFinalStepInplace for ZnxRef {
#[inline(always)]
fn znx_normalize_final_step_inplace(basek: usize, lsh: usize, x: &mut [i64], carry: &mut [i64]) {
znx_normalize_final_step_inplace_ref(basek, lsh, x, carry);
fn znx_normalize_final_step_inplace(base2k: usize, lsh: usize, x: &mut [i64], carry: &mut [i64]) {
znx_normalize_final_step_inplace_ref(base2k, lsh, x, carry);
}
}
impl ZnxNormalizeFirstStep for ZnxRef {
#[inline(always)]
fn znx_normalize_first_step(basek: usize, lsh: usize, x: &mut [i64], a: &[i64], carry: &mut [i64]) {
znx_normalize_first_step_ref(basek, lsh, x, a, carry);
fn znx_normalize_first_step(base2k: usize, lsh: usize, x: &mut [i64], a: &[i64], carry: &mut [i64]) {
znx_normalize_first_step_ref(base2k, lsh, x, a, carry);
}
}
impl ZnxNormalizeFirstStepCarryOnly for ZnxRef {
#[inline(always)]
fn znx_normalize_first_step_carry_only(basek: usize, lsh: usize, x: &[i64], carry: &mut [i64]) {
znx_normalize_first_step_carry_only_ref(basek, lsh, x, carry);
fn znx_normalize_first_step_carry_only(base2k: usize, lsh: usize, x: &[i64], carry: &mut [i64]) {
znx_normalize_first_step_carry_only_ref(base2k, lsh, x, carry);
}
}
impl ZnxNormalizeFirstStepInplace for ZnxRef {
#[inline(always)]
fn znx_normalize_first_step_inplace(basek: usize, lsh: usize, x: &mut [i64], carry: &mut [i64]) {
znx_normalize_first_step_inplace_ref(basek, lsh, x, carry);
fn znx_normalize_first_step_inplace(base2k: usize, lsh: usize, x: &mut [i64], carry: &mut [i64]) {
znx_normalize_first_step_inplace_ref(base2k, lsh, x, carry);
}
}
impl ZnxNormalizeMiddleStep for ZnxRef {
#[inline(always)]
fn znx_normalize_middle_step(basek: usize, lsh: usize, x: &mut [i64], a: &[i64], carry: &mut [i64]) {
znx_normalize_middle_step_ref(basek, lsh, x, a, carry);
fn znx_normalize_middle_step(base2k: usize, lsh: usize, x: &mut [i64], a: &[i64], carry: &mut [i64]) {
znx_normalize_middle_step_ref(base2k, lsh, x, a, carry);
}
}
impl ZnxNormalizeMiddleStepCarryOnly for ZnxRef {
#[inline(always)]
fn znx_normalize_middle_step_carry_only(basek: usize, lsh: usize, x: &[i64], carry: &mut [i64]) {
znx_normalize_middle_step_carry_only_ref(basek, lsh, x, carry);
fn znx_normalize_middle_step_carry_only(base2k: usize, lsh: usize, x: &[i64], carry: &mut [i64]) {
znx_normalize_middle_step_carry_only_ref(base2k, lsh, x, carry);
}
}
impl ZnxNormalizeMiddleStepInplace for ZnxRef {
#[inline(always)]
fn znx_normalize_middle_step_inplace(basek: usize, lsh: usize, x: &mut [i64], carry: &mut [i64]) {
znx_normalize_middle_step_inplace_ref(basek, lsh, x, carry);
fn znx_normalize_middle_step_inplace(base2k: usize, lsh: usize, x: &mut [i64], carry: &mut [i64]) {
znx_normalize_middle_step_inplace_ref(base2k, lsh, x, carry);
}
}
impl ZnxExtractDigitAddMul for ZnxRef {
#[inline(always)]
fn znx_extract_digit_addmul(base2k: usize, lsh: usize, res: &mut [i64], src: &mut [i64]) {
znx_extract_digit_addmul_ref(base2k, lsh, res, src);
}
}
impl ZnxNormalizeDigit for ZnxRef {
#[inline(always)]
fn znx_normalize_digit(base2k: usize, res: &mut [i64], src: &mut [i64]) {
znx_normalize_digit_ref(base2k, res, src);
}
}
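
ZnxRef is the portable scalar backend: every impl above is a thin forwarder to the matching *_ref function. Below is a minimal sketch of driving the two new digit primitives through it; the re-export of ZnxRef from crate::reference::znx, the test-module name, and the concrete values are assumptions for illustration, not part of this commit.

#[cfg(test)]
mod cross_base2k_digit_sketch {
    use crate::reference::znx::{ZnxExtractDigitAddMul, ZnxNormalizeDigit, ZnxRef};

    #[test]
    fn extract_then_normalize() {
        let base2k: usize = 4; // digits live in [-8, 8)
        let mut src: Vec<i64> = vec![100]; // 100 = 6*16 + 4
        let mut res: Vec<i64> = vec![7];

        // res += (low digit of src) << 0; src keeps only the carry.
        <ZnxRef as ZnxExtractDigitAddMul>::znx_extract_digit_addmul(base2k, 0, &mut res, &mut src);
        assert_eq!(res, vec![11]); // 7 + 4
        assert_eq!(src, vec![6]);  // (100 - 4) >> 4

        // 11 falls outside the centered digit range; fold its carry into src.
        <ZnxRef as ZnxNormalizeDigit>::znx_normalize_digit(base2k, &mut res, &mut src);
        assert_eq!(res, vec![-5]); // 11 ≡ -5 (mod 16)
        assert_eq!(src, vec![7]);  // 6 + 1
    }
}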

View File

@@ -2,6 +2,7 @@ mod add;
mod arithmetic_ref;
mod automorphism;
mod copy;
mod mul;
mod neg;
mod normalization;
mod rotate;
@@ -14,6 +15,7 @@ pub use add::*;
pub use arithmetic_ref::*;
pub use automorphism::*;
pub use copy::*;
pub use mul::*;
pub use neg::*;
pub use normalization::*;
pub use rotate::*;
@@ -35,12 +37,12 @@ pub trait ZnxSub {
fn znx_sub(res: &mut [i64], a: &[i64], b: &[i64]);
}
pub trait ZnxSubABInplace {
fn znx_sub_ab_inplace(res: &mut [i64], a: &[i64]);
pub trait ZnxSubInplace {
fn znx_sub_inplace(res: &mut [i64], a: &[i64]);
}
pub trait ZnxSubBAInplace {
fn znx_sub_ba_inplace(res: &mut [i64], a: &[i64]);
pub trait ZnxSubNegateInplace {
fn znx_sub_negate_inplace(res: &mut [i64], a: &[i64]);
}
pub trait ZnxAutomorphism {
@@ -67,38 +69,58 @@ pub trait ZnxZero {
fn znx_zero(res: &mut [i64]);
}
pub trait ZnxMulPowerOfTwo {
fn znx_mul_power_of_two(k: i64, res: &mut [i64], a: &[i64]);
}
pub trait ZnxMulAddPowerOfTwo {
fn znx_muladd_power_of_two(k: i64, res: &mut [i64], a: &[i64]);
}
pub trait ZnxMulPowerOfTwoInplace {
fn znx_mul_power_of_two_inplace(k: i64, res: &mut [i64]);
}
pub trait ZnxSwitchRing {
fn znx_switch_ring(res: &mut [i64], a: &[i64]);
}
pub trait ZnxNormalizeFirstStepCarryOnly {
fn znx_normalize_first_step_carry_only(basek: usize, lsh: usize, x: &[i64], carry: &mut [i64]);
fn znx_normalize_first_step_carry_only(base2k: usize, lsh: usize, x: &[i64], carry: &mut [i64]);
}
pub trait ZnxNormalizeFirstStepInplace {
fn znx_normalize_first_step_inplace(basek: usize, lsh: usize, x: &mut [i64], carry: &mut [i64]);
fn znx_normalize_first_step_inplace(base2k: usize, lsh: usize, x: &mut [i64], carry: &mut [i64]);
}
pub trait ZnxNormalizeFirstStep {
fn znx_normalize_first_step(basek: usize, lsh: usize, x: &mut [i64], a: &[i64], carry: &mut [i64]);
fn znx_normalize_first_step(base2k: usize, lsh: usize, x: &mut [i64], a: &[i64], carry: &mut [i64]);
}
pub trait ZnxNormalizeMiddleStepCarryOnly {
fn znx_normalize_middle_step_carry_only(basek: usize, lsh: usize, x: &[i64], carry: &mut [i64]);
fn znx_normalize_middle_step_carry_only(base2k: usize, lsh: usize, x: &[i64], carry: &mut [i64]);
}
pub trait ZnxNormalizeMiddleStepInplace {
fn znx_normalize_middle_step_inplace(basek: usize, lsh: usize, x: &mut [i64], carry: &mut [i64]);
fn znx_normalize_middle_step_inplace(base2k: usize, lsh: usize, x: &mut [i64], carry: &mut [i64]);
}
pub trait ZnxNormalizeMiddleStep {
fn znx_normalize_middle_step(basek: usize, lsh: usize, x: &mut [i64], a: &[i64], carry: &mut [i64]);
fn znx_normalize_middle_step(base2k: usize, lsh: usize, x: &mut [i64], a: &[i64], carry: &mut [i64]);
}
pub trait ZnxNormalizeFinalStepInplace {
fn znx_normalize_final_step_inplace(basek: usize, lsh: usize, x: &mut [i64], carry: &mut [i64]);
fn znx_normalize_final_step_inplace(base2k: usize, lsh: usize, x: &mut [i64], carry: &mut [i64]);
}
pub trait ZnxNormalizeFinalStep {
fn znx_normalize_final_step(basek: usize, lsh: usize, x: &mut [i64], a: &[i64], carry: &mut [i64]);
fn znx_normalize_final_step(base2k: usize, lsh: usize, x: &mut [i64], a: &[i64], carry: &mut [i64]);
}
pub trait ZnxExtractDigitAddMul {
fn znx_extract_digit_addmul(base2k: usize, lsh: usize, res: &mut [i64], src: &mut [i64]);
}
pub trait ZnxNormalizeDigit {
fn znx_normalize_digit(base2k: usize, res: &mut [i64], src: &mut [i64]);
}
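
Higher layers can stay backend-agnostic by bounding on these traits instead of calling the *_ref functions directly. The sketch below shows that pattern; the helper name fold_digit_into and the premise that non-reference (e.g. vectorized) backends implement the same traits are illustrative assumptions, not part of this commit.

use crate::reference::znx::{ZnxExtractDigitAddMul, ZnxNormalizeDigit};

fn fold_digit_into<B: ZnxExtractDigitAddMul + ZnxNormalizeDigit>(
    base2k: usize,
    lsh: usize,
    res: &mut [i64],
    src: &mut [i64],
) {
    // Move the low base2k-bit digit of `src` into `res` at bit offset `lsh`...
    B::znx_extract_digit_addmul(base2k, lsh, res, src);
    // ...then re-center `res`, pushing its carry back into `src`.
    B::znx_normalize_digit(base2k, res, src);
}

The reference backend added in this commit plugs straight in as fold_digit_into::<ZnxRef>(base2k, 0, &mut res, &mut src).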

View File

@@ -0,0 +1,76 @@
use crate::reference::znx::{znx_add_inplace_ref, znx_copy_ref};
/// Writes `round(a * 2^k)` into `res`: an exact left shift for `k >= 0`, a
/// right shift with round-to-nearest (ties away from zero) for `k < 0`.
pub fn znx_mul_power_of_two_ref(mut k: i64, res: &mut [i64], a: &[i64]) {
#[cfg(debug_assertions)]
{
assert_eq!(res.len(), a.len());
}
if k == 0 {
znx_copy_ref(res, a);
return;
}
if k > 0 {
for (y, x) in res.iter_mut().zip(a.iter()) {
*y = *x << k
}
return;
}
k = -k;
for (y, x) in res.iter_mut().zip(a.iter()) {
let sign_bit: i64 = (x >> 63) & 1;
let bias: i64 = (1_i64 << (k - 1)) - sign_bit;
*y = (x + bias) >> k;
}
}
/// In-place variant of [`znx_mul_power_of_two_ref`]: scales `res` by `2^k`
/// with the same rounding rule for negative `k`.
pub fn znx_mul_power_of_two_inplace_ref(mut k: i64, res: &mut [i64]) {
if k == 0 {
return;
}
if k > 0 {
for x in res.iter_mut() {
*x <<= k
}
return;
}
k = -k;
for x in res.iter_mut() {
let sign_bit: i64 = (*x >> 63) & 1;
let bias: i64 = (1_i64 << (k - 1)) - sign_bit;
*x = (*x + bias) >> k;
}
}
/// Accumulates `round(a * 2^k)` into `res` (i.e. `res += round(a * 2^k)`),
/// using the same rounding rule as [`znx_mul_power_of_two_ref`].
pub fn znx_mul_add_power_of_two_ref(mut k: i64, res: &mut [i64], a: &[i64]) {
#[cfg(debug_assertions)]
{
assert_eq!(res.len(), a.len());
}
if k == 0 {
znx_add_inplace_ref(res, a);
return;
}
if k > 0 {
for (y, x) in res.iter_mut().zip(a.iter()) {
*y += *x << k
}
return;
}
k = -k;
for (y, x) in res.iter_mut().zip(a.iter()) {
let sign_bit: i64 = (x >> 63) & 1;
let bias: i64 = (1_i64 << (k - 1)) - sign_bit;
*y += (x + bias) >> k;
}
}
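
The negative-k branches above implement a rounded arithmetic right shift (round to nearest, ties away from zero). Below is a minimal check of that rule; the import path assumes these helpers are re-exported via the pub use mul::*; added to the module file in this commit, and the test name and values are chosen for illustration.

#[cfg(test)]
mod mul_power_of_two_rounding_sketch {
    use crate::reference::znx::{znx_mul_power_of_two_inplace_ref, znx_mul_power_of_two_ref};

    #[test]
    fn rounds_ties_away_from_zero() {
        // k >= 0: exact scaling by 2^k.
        let mut res = [0_i64; 3];
        znx_mul_power_of_two_ref(3, &mut res, &[1, -2, 5]);
        assert_eq!(res, [8, -16, 40]);

        // k < 0: divide by 2^|k| and round to nearest, ties away from zero.
        let mut res = [5_i64, -5, 3, -6];
        znx_mul_power_of_two_inplace_ref(-1, &mut res[..2]); // 2.5 -> 3, -2.5 -> -3
        znx_mul_power_of_two_inplace_ref(-2, &mut res[2..]); // 0.75 -> 1, -1.5 -> -2
        assert_eq!(res, [3, -3, 1, -2]);
    }
}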

View File

@@ -1,199 +1,229 @@
use itertools::izip;
#[inline(always)]
pub fn get_digit(basek: usize, x: i64) -> i64 {
(x << (u64::BITS - basek as u32)) >> (u64::BITS - basek as u32)
pub fn get_digit_i64(base2k: usize, x: i64) -> i64 {
(x << (u64::BITS - base2k as u32)) >> (u64::BITS - base2k as u32)
}
#[inline(always)]
pub fn get_carry(basek: usize, x: i64, digit: i64) -> i64 {
(x.wrapping_sub(digit)) >> basek
pub fn get_carry_i64(base2k: usize, x: i64, digit: i64) -> i64 {
(x.wrapping_sub(digit)) >> base2k
}
#[inline(always)]
pub fn znx_normalize_first_step_carry_only_ref(basek: usize, lsh: usize, x: &[i64], carry: &mut [i64]) {
pub fn get_digit_i128(base2k: usize, x: i128) -> i128 {
(x << (u128::BITS - base2k as u32)) >> (u128::BITS - base2k as u32)
}
#[inline(always)]
pub fn get_carry_i128(base2k: usize, x: i128, digit: i128) -> i128 {
(x.wrapping_sub(digit)) >> base2k
}
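
get_digit_i64 and get_carry_i64 (and their i128 twins, presumably for wider intermediate values) split a coefficient into a centered digit and a carry such that x == digit + (carry << base2k), with the digit in [-2^(base2k-1), 2^(base2k-1)). The self-contained check below illustrates that invariant; the test name and values are chosen for illustration.

#[test]
fn digit_carry_invariant_sketch() {
    use crate::reference::znx::{get_carry_i64, get_digit_i64};

    let base2k: usize = 4; // digits are centered in [-8, 8)
    for x in [100_i64, 9, -9, 7] {
        let digit = get_digit_i64(base2k, x);
        let carry = get_carry_i64(base2k, x, digit);
        // Every normalization step in this file relies on this reconstruction.
        assert_eq!(x, digit + (carry << base2k));
        assert!(-(1 << (base2k - 1)) <= digit && digit < (1 << (base2k - 1)));
    }
    // e.g. 100 -> (digit 4, carry 6); 9 -> (digit -7, carry 1).
}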
#[inline(always)]
pub fn znx_normalize_first_step_carry_only_ref(base2k: usize, lsh: usize, x: &[i64], carry: &mut [i64]) {
#[cfg(debug_assertions)]
{
assert!(x.len() <= carry.len());
assert!(lsh < basek);
assert!(lsh < base2k);
}
if lsh == 0 {
x.iter().zip(carry.iter_mut()).for_each(|(x, c)| {
*c = get_carry(basek, *x, get_digit(basek, *x));
*c = get_carry_i64(base2k, *x, get_digit_i64(base2k, *x));
});
} else {
let basek_lsh: usize = basek - lsh;
let basek_lsh: usize = base2k - lsh;
x.iter().zip(carry.iter_mut()).for_each(|(x, c)| {
*c = get_carry(basek_lsh, *x, get_digit(basek_lsh, *x));
*c = get_carry_i64(basek_lsh, *x, get_digit_i64(basek_lsh, *x));
});
}
}
#[inline(always)]
pub fn znx_normalize_first_step_inplace_ref(basek: usize, lsh: usize, x: &mut [i64], carry: &mut [i64]) {
pub fn znx_normalize_first_step_inplace_ref(base2k: usize, lsh: usize, x: &mut [i64], carry: &mut [i64]) {
#[cfg(debug_assertions)]
{
assert!(x.len() <= carry.len());
assert!(lsh < basek);
assert!(lsh < base2k);
}
if lsh == 0 {
x.iter_mut().zip(carry.iter_mut()).for_each(|(x, c)| {
let digit: i64 = get_digit(basek, *x);
*c = get_carry(basek, *x, digit);
let digit: i64 = get_digit_i64(base2k, *x);
*c = get_carry_i64(base2k, *x, digit);
*x = digit;
});
} else {
let basek_lsh: usize = basek - lsh;
let basek_lsh: usize = base2k - lsh;
x.iter_mut().zip(carry.iter_mut()).for_each(|(x, c)| {
let digit: i64 = get_digit(basek_lsh, *x);
*c = get_carry(basek_lsh, *x, digit);
let digit: i64 = get_digit_i64(basek_lsh, *x);
*c = get_carry_i64(basek_lsh, *x, digit);
*x = digit << lsh;
});
}
}
#[inline(always)]
pub fn znx_normalize_first_step_ref(basek: usize, lsh: usize, x: &mut [i64], a: &[i64], carry: &mut [i64]) {
pub fn znx_normalize_first_step_ref(base2k: usize, lsh: usize, x: &mut [i64], a: &[i64], carry: &mut [i64]) {
#[cfg(debug_assertions)]
{
assert_eq!(x.len(), a.len());
assert!(x.len() <= carry.len());
assert!(lsh < basek);
assert!(lsh < base2k);
}
if lsh == 0 {
izip!(x.iter_mut(), a.iter(), carry.iter_mut()).for_each(|(x, a, c)| {
let digit: i64 = get_digit(basek, *a);
*c = get_carry(basek, *a, digit);
let digit: i64 = get_digit_i64(base2k, *a);
*c = get_carry_i64(base2k, *a, digit);
*x = digit;
});
} else {
let basek_lsh: usize = basek - lsh;
let basek_lsh: usize = base2k - lsh;
izip!(x.iter_mut(), a.iter(), carry.iter_mut()).for_each(|(x, a, c)| {
let digit: i64 = get_digit(basek_lsh, *a);
*c = get_carry(basek_lsh, *a, digit);
let digit: i64 = get_digit_i64(basek_lsh, *a);
*c = get_carry_i64(basek_lsh, *a, digit);
*x = digit << lsh;
});
}
}
#[inline(always)]
pub fn znx_normalize_middle_step_carry_only_ref(basek: usize, lsh: usize, x: &[i64], carry: &mut [i64]) {
pub fn znx_normalize_middle_step_carry_only_ref(base2k: usize, lsh: usize, x: &[i64], carry: &mut [i64]) {
#[cfg(debug_assertions)]
{
assert!(x.len() <= carry.len());
assert!(lsh < basek);
assert!(lsh < base2k);
}
if lsh == 0 {
x.iter().zip(carry.iter_mut()).for_each(|(x, c)| {
let digit: i64 = get_digit(basek, *x);
let carry: i64 = get_carry(basek, *x, digit);
let digit: i64 = get_digit_i64(base2k, *x);
let carry: i64 = get_carry_i64(base2k, *x, digit);
let digit_plus_c: i64 = digit + *c;
*c = carry + get_carry(basek, digit_plus_c, get_digit(basek, digit_plus_c));
*c = carry + get_carry_i64(base2k, digit_plus_c, get_digit_i64(base2k, digit_plus_c));
});
} else {
let basek_lsh: usize = basek - lsh;
let basek_lsh: usize = base2k - lsh;
x.iter().zip(carry.iter_mut()).for_each(|(x, c)| {
let digit: i64 = get_digit(basek_lsh, *x);
let carry: i64 = get_carry(basek_lsh, *x, digit);
let digit: i64 = get_digit_i64(basek_lsh, *x);
let carry: i64 = get_carry_i64(basek_lsh, *x, digit);
let digit_plus_c: i64 = (digit << lsh) + *c;
*c = carry + get_carry(basek, digit_plus_c, get_digit(basek, digit_plus_c));
*c = carry + get_carry_i64(base2k, digit_plus_c, get_digit_i64(base2k, digit_plus_c));
});
}
}
#[inline(always)]
pub fn znx_normalize_middle_step_inplace_ref(basek: usize, lsh: usize, x: &mut [i64], carry: &mut [i64]) {
pub fn znx_normalize_middle_step_inplace_ref(base2k: usize, lsh: usize, x: &mut [i64], carry: &mut [i64]) {
#[cfg(debug_assertions)]
{
assert!(x.len() <= carry.len());
assert!(lsh < basek);
assert!(lsh < base2k);
}
if lsh == 0 {
x.iter_mut().zip(carry.iter_mut()).for_each(|(x, c)| {
let digit: i64 = get_digit(basek, *x);
let carry: i64 = get_carry(basek, *x, digit);
let digit: i64 = get_digit_i64(base2k, *x);
let carry: i64 = get_carry_i64(base2k, *x, digit);
let digit_plus_c: i64 = digit + *c;
*x = get_digit(basek, digit_plus_c);
*c = carry + get_carry(basek, digit_plus_c, *x);
*x = get_digit_i64(base2k, digit_plus_c);
*c = carry + get_carry_i64(base2k, digit_plus_c, *x);
});
} else {
let basek_lsh: usize = basek - lsh;
let basek_lsh: usize = base2k - lsh;
x.iter_mut().zip(carry.iter_mut()).for_each(|(x, c)| {
let digit: i64 = get_digit(basek_lsh, *x);
let carry: i64 = get_carry(basek_lsh, *x, digit);
let digit: i64 = get_digit_i64(basek_lsh, *x);
let carry: i64 = get_carry_i64(basek_lsh, *x, digit);
let digit_plus_c: i64 = (digit << lsh) + *c;
*x = get_digit(basek, digit_plus_c);
*c = carry + get_carry(basek, digit_plus_c, *x);
*x = get_digit_i64(base2k, digit_plus_c);
*c = carry + get_carry_i64(base2k, digit_plus_c, *x);
});
}
}
#[inline(always)]
pub fn znx_normalize_middle_step_ref(basek: usize, lsh: usize, x: &mut [i64], a: &[i64], carry: &mut [i64]) {
/// Extracts the centered low `base2k`-bit digit of each coefficient of `src`,
/// accumulates it into `res` shifted left by `lsh`, and leaves only the
/// carry behind in `src`.
pub fn znx_extract_digit_addmul_ref(base2k: usize, lsh: usize, res: &mut [i64], src: &mut [i64]) {
for (r, s) in res.iter_mut().zip(src.iter_mut()) {
let digit: i64 = get_digit_i64(base2k, *s);
*s = get_carry_i64(base2k, *s, digit);
*r += digit << lsh;
}
}
/// Reduces each coefficient of `res` to a centered `base2k`-bit digit and
/// pushes the resulting carry into the matching coefficient of `src`.
#[inline(always)]
pub fn znx_normalize_digit_ref(base2k: usize, res: &mut [i64], src: &mut [i64]) {
for (r, s) in res.iter_mut().zip(src.iter_mut()) {
let ri_digit: i64 = get_digit_i64(base2k, *r);
let ri_carry: i64 = get_carry_i64(base2k, *r, ri_digit);
*r = ri_digit;
*s += ri_carry;
}
}
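
These two primitives are plausibly the building blocks for the cross-base2k normalization this commit adds: znx_extract_digit_addmul_ref peels the low base2k bits off src and deposits them into res at bit offset lsh, and znx_normalize_digit_ref re-centers res afterwards. Peeling repeatedly with an increasing shift reassembles the original coefficient, which makes for a compact sanity check; the test name, values, and import path below are assumptions for illustration.

#[test]
fn peel_digits_sketch() {
    use crate::reference::znx::znx_extract_digit_addmul_ref;

    let base2k: usize = 4;
    let mut src = [0x0123_i64, -77];
    let mut res = [0_i64, 0];

    // Peel the low digit off `src` three times, shifting each successive
    // digit 4 bits further left: the signed digit/carry split guarantees
    // the original values are reassembled, even for negative inputs.
    for i in 0..3 {
        znx_extract_digit_addmul_ref(base2k, i * base2k, &mut res, &mut src);
    }
    assert_eq!(res, [0x0123, -77]);
    assert_eq!(src, [0, 0]);
}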
#[inline(always)]
pub fn znx_normalize_middle_step_ref(base2k: usize, lsh: usize, x: &mut [i64], a: &[i64], carry: &mut [i64]) {
#[cfg(debug_assertions)]
{
assert_eq!(x.len(), a.len());
assert!(x.len() <= carry.len());
assert!(lsh < basek);
assert!(lsh < base2k);
}
if lsh == 0 {
izip!(x.iter_mut(), a.iter(), carry.iter_mut()).for_each(|(x, a, c)| {
let digit: i64 = get_digit(basek, *a);
let carry: i64 = get_carry(basek, *a, digit);
let digit: i64 = get_digit_i64(base2k, *a);
let carry: i64 = get_carry_i64(base2k, *a, digit);
let digit_plus_c: i64 = digit + *c;
*x = get_digit(basek, digit_plus_c);
*c = carry + get_carry(basek, digit_plus_c, *x);
*x = get_digit_i64(base2k, digit_plus_c);
*c = carry + get_carry_i64(base2k, digit_plus_c, *x);
});
} else {
let basek_lsh: usize = basek - lsh;
let basek_lsh: usize = base2k - lsh;
izip!(x.iter_mut(), a.iter(), carry.iter_mut()).for_each(|(x, a, c)| {
let digit: i64 = get_digit(basek_lsh, *a);
let carry: i64 = get_carry(basek_lsh, *a, digit);
let digit: i64 = get_digit_i64(basek_lsh, *a);
let carry: i64 = get_carry_i64(basek_lsh, *a, digit);
let digit_plus_c: i64 = (digit << lsh) + *c;
*x = get_digit(basek, digit_plus_c);
*c = carry + get_carry(basek, digit_plus_c, *x);
*x = get_digit_i64(base2k, digit_plus_c);
*c = carry + get_carry_i64(base2k, digit_plus_c, *x);
});
}
}
#[inline(always)]
pub fn znx_normalize_final_step_inplace_ref(basek: usize, lsh: usize, x: &mut [i64], carry: &mut [i64]) {
pub fn znx_normalize_final_step_inplace_ref(base2k: usize, lsh: usize, x: &mut [i64], carry: &mut [i64]) {
#[cfg(debug_assertions)]
{
assert!(x.len() <= carry.len());
assert!(lsh < basek);
assert!(lsh < base2k);
}
if lsh == 0 {
x.iter_mut().zip(carry.iter_mut()).for_each(|(x, c)| {
*x = get_digit(basek, get_digit(basek, *x) + *c);
*x = get_digit_i64(base2k, get_digit_i64(base2k, *x) + *c);
});
} else {
let basek_lsh: usize = basek - lsh;
let basek_lsh: usize = base2k - lsh;
x.iter_mut().zip(carry.iter_mut()).for_each(|(x, c)| {
*x = get_digit(basek, (get_digit(basek_lsh, *x) << lsh) + *c);
*x = get_digit_i64(base2k, (get_digit_i64(basek_lsh, *x) << lsh) + *c);
});
}
}
#[inline(always)]
pub fn znx_normalize_final_step_ref(basek: usize, lsh: usize, x: &mut [i64], a: &[i64], carry: &mut [i64]) {
pub fn znx_normalize_final_step_ref(base2k: usize, lsh: usize, x: &mut [i64], a: &[i64], carry: &mut [i64]) {
#[cfg(debug_assertions)]
{
assert!(x.len() <= carry.len());
assert!(lsh < basek);
assert!(lsh < base2k);
}
if lsh == 0 {
izip!(x.iter_mut(), a.iter(), carry.iter_mut()).for_each(|(x, a, c)| {
*x = get_digit(basek, get_digit(basek, *a) + *c);
*x = get_digit_i64(base2k, get_digit_i64(base2k, *a) + *c);
});
} else {
let basek_lsh: usize = basek - lsh;
let basek_lsh: usize = base2k - lsh;
izip!(x.iter_mut(), a.iter(), carry.iter_mut()).for_each(|(x, a, c)| {
*x = get_digit(basek, (get_digit(basek_lsh, *a) << lsh) + *c);
*x = get_digit_i64(base2k, (get_digit_i64(basek_lsh, *a) << lsh) + *c);
});
}
}
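
These step families are meant to be chained limb by limb, starting from the least significant limb so carries ripple toward the more significant ones: first_step seeds the carry, middle_step consumes and regenerates it, and final_step folds the remaining carry and drops any overflow beyond the top limb. The single-coefficient sketch below walks that pipeline; the toy limb layout (most significant limb first), test name, and values are assumptions made for illustration.

#[test]
fn three_step_pipeline_sketch() {
    use crate::reference::znx::{
        znx_normalize_final_step_ref, znx_normalize_first_step_ref, znx_normalize_middle_step_ref,
    };

    let base2k: usize = 4;
    // One coefficient spread over three base-2^4 limbs (most significant
    // first), with limb values outside the centered [-8, 8) digit range:
    // 1*256 + 9*16 + 23 = 423.
    let a = [[1_i64], [9_i64], [23_i64]];
    let mut x = [[0_i64], [0_i64], [0_i64]];
    let mut carry = [0_i64];

    // Walk from the least significant limb upwards so carries propagate.
    znx_normalize_first_step_ref(base2k, 0, &mut x[2], &a[2], &mut carry);
    znx_normalize_middle_step_ref(base2k, 0, &mut x[1], &a[1], &mut carry);
    znx_normalize_final_step_ref(base2k, 0, &mut x[0], &a[0], &mut carry);

    // Same value, every digit now in [-8, 8): 2*256 - 6*16 + 7 = 423.
    assert_eq!(x, [[2], [-6], [7]]);
}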

View File

@@ -2,8 +2,8 @@ use rand_distr::{Distribution, Normal};
use crate::source::Source;
pub fn znx_fill_uniform_ref(basek: usize, res: &mut [i64], source: &mut Source) {
let pow2k: u64 = 1 << basek;
pub fn znx_fill_uniform_ref(base2k: usize, res: &mut [i64], source: &mut Source) {
let pow2k: u64 = 1 << base2k;
let mask: u64 = pow2k - 1;
let pow2k_half: i64 = (pow2k >> 1) as i64;
res.iter_mut()

View File

@@ -11,7 +11,7 @@ pub fn znx_sub_ref(res: &mut [i64], a: &[i64], b: &[i64]) {
}
}
pub fn znx_sub_ab_inplace_ref(res: &mut [i64], a: &[i64]) {
pub fn znx_sub_inplace_ref(res: &mut [i64], a: &[i64]) {
#[cfg(debug_assertions)]
{
assert_eq!(res.len(), a.len());
@@ -23,7 +23,7 @@ pub fn znx_sub_ab_inplace_ref(res: &mut [i64], a: &[i64]) {
}
}
pub fn znx_sub_ba_inplace_ref(res: &mut [i64], a: &[i64]) {
pub fn znx_sub_negate_inplace_ref(res: &mut [i64], a: &[i64]) {
#[cfg(debug_assertions)]
{
assert_eq!(res.len(), a.len());