mirror of
https://github.com/arnaucube/poulpy.git
synced 2026-02-10 13:16:44 +01:00
Add cross-basek normalization (#90)
* added cross_basek_normalization * updated method signatures to take layouts * fixed cross-base normalization fix #91 fix #93
This commit is contained in:
committed by
GitHub
parent
4da790ea6a
commit
37e13b965c
@@ -91,7 +91,7 @@ pub fn bench_vec_znx_add<B: Backend>(c: &mut Criterion, label: &str)
|
||||
where
|
||||
Module<B>: VecZnxAdd + ModuleNew<B>,
|
||||
{
|
||||
let group_name: String = format!("vec_znx_add::{}", label);
|
||||
let group_name: String = format!("vec_znx_add::{label}");
|
||||
|
||||
let mut group = c.benchmark_group(group_name);
|
||||
|
||||
@@ -136,7 +136,7 @@ pub fn bench_vec_znx_add_inplace<B: Backend>(c: &mut Criterion, label: &str)
|
||||
where
|
||||
Module<B>: VecZnxAddInplace + ModuleNew<B>,
|
||||
{
|
||||
let group_name: String = format!("vec_znx_add_inplace::{}", label);
|
||||
let group_name: String = format!("vec_znx_add_inplace::{label}");
|
||||
|
||||
let mut group = c.benchmark_group(group_name);
|
||||
|
||||
|
||||
@@ -18,12 +18,7 @@ where
|
||||
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
assert!(
|
||||
b_limb < min_size,
|
||||
"b_limb: {} > min_size: {}",
|
||||
b_limb,
|
||||
min_size
|
||||
);
|
||||
assert!(b_limb < min_size, "b_limb: {b_limb} > min_size: {min_size}");
|
||||
}
|
||||
|
||||
for j in 0..min_size {
|
||||
|
||||
@@ -63,7 +63,7 @@ pub fn bench_vec_znx_automorphism<B: Backend>(c: &mut Criterion, label: &str)
|
||||
where
|
||||
Module<B>: VecZnxAutomorphism + ModuleNew<B>,
|
||||
{
|
||||
let group_name: String = format!("vec_znx_automorphism::{}", label);
|
||||
let group_name: String = format!("vec_znx_automorphism::{label}");
|
||||
|
||||
let mut group = c.benchmark_group(group_name);
|
||||
|
||||
@@ -108,7 +108,7 @@ where
|
||||
Module<B>: VecZnxAutomorphismInplace<B> + VecZnxAutomorphismInplaceTmpBytes + ModuleNew<B>,
|
||||
ScratchOwned<B>: ScratchOwnedAlloc<B> + ScratchOwnedBorrow<B>,
|
||||
{
|
||||
let group_name: String = format!("vec_znx_automorphism_inplace::{}", label);
|
||||
let group_name: String = format!("vec_znx_automorphism_inplace::{label}");
|
||||
|
||||
let mut group = c.benchmark_group(group_name);
|
||||
|
||||
|
||||
@@ -9,8 +9,8 @@ use crate::{
|
||||
},
|
||||
layouts::{Backend, FillUniform, Module, ScratchOwned, VecZnx, VecZnxToMut, VecZnxToRef, ZnxInfos, ZnxView, ZnxViewMut},
|
||||
reference::{
|
||||
vec_znx::{vec_znx_rotate, vec_znx_sub_ab_inplace},
|
||||
znx::{ZnxNegate, ZnxRotate, ZnxSubABInplace, ZnxSubBAInplace, ZnxZero},
|
||||
vec_znx::{vec_znx_rotate, vec_znx_sub_inplace},
|
||||
znx::{ZnxNegate, ZnxRotate, ZnxSubInplace, ZnxSubNegateInplace, ZnxZero},
|
||||
},
|
||||
source::Source,
|
||||
};
|
||||
@@ -23,16 +23,16 @@ pub fn vec_znx_mul_xp_minus_one<R, A, ZNXARI>(p: i64, res: &mut R, res_col: usiz
|
||||
where
|
||||
R: VecZnxToMut,
|
||||
A: VecZnxToRef,
|
||||
ZNXARI: ZnxRotate + ZnxZero + ZnxSubABInplace,
|
||||
ZNXARI: ZnxRotate + ZnxZero + ZnxSubInplace,
|
||||
{
|
||||
vec_znx_rotate::<_, _, ZNXARI>(p, res, res_col, a, a_col);
|
||||
vec_znx_sub_ab_inplace::<_, _, ZNXARI>(res, res_col, a, a_col);
|
||||
vec_znx_sub_inplace::<_, _, ZNXARI>(res, res_col, a, a_col);
|
||||
}
|
||||
|
||||
pub fn vec_znx_mul_xp_minus_one_inplace<R, ZNXARI>(p: i64, res: &mut R, res_col: usize, tmp: &mut [i64])
|
||||
where
|
||||
R: VecZnxToMut,
|
||||
ZNXARI: ZnxRotate + ZnxNegate + ZnxSubBAInplace,
|
||||
ZNXARI: ZnxRotate + ZnxNegate + ZnxSubNegateInplace,
|
||||
{
|
||||
let mut res: VecZnx<&mut [u8]> = res.to_mut();
|
||||
#[cfg(debug_assertions)]
|
||||
@@ -41,7 +41,7 @@ where
|
||||
}
|
||||
for j in 0..res.size() {
|
||||
ZNXARI::znx_rotate(p, tmp, res.at(res_col, j));
|
||||
ZNXARI::znx_sub_ba_inplace(res.at_mut(res_col, j), tmp);
|
||||
ZNXARI::znx_sub_negate_inplace(res.at_mut(res_col, j), tmp);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -49,7 +49,7 @@ pub fn bench_vec_znx_mul_xp_minus_one<B: Backend>(c: &mut Criterion, label: &str
|
||||
where
|
||||
Module<B>: VecZnxMulXpMinusOne + ModuleNew<B>,
|
||||
{
|
||||
let group_name: String = format!("vec_znx_mul_xp_minus_one::{}", label);
|
||||
let group_name: String = format!("vec_znx_mul_xp_minus_one::{label}");
|
||||
|
||||
let mut group = c.benchmark_group(group_name);
|
||||
|
||||
@@ -94,7 +94,7 @@ where
|
||||
Module<B>: VecZnxMulXpMinusOneInplace<B> + VecZnxMulXpMinusOneInplaceTmpBytes + ModuleNew<B>,
|
||||
ScratchOwned<B>: ScratchOwnedAlloc<B> + ScratchOwnedBorrow<B>,
|
||||
{
|
||||
let group_name: String = format!("vec_znx_mul_xp_minus_one_inplace::{}", label);
|
||||
let group_name: String = format!("vec_znx_mul_xp_minus_one_inplace::{label}");
|
||||
|
||||
let mut group = c.benchmark_group(group_name);
|
||||
|
||||
|
||||
@@ -49,7 +49,7 @@ pub fn bench_vec_znx_negate<B: Backend>(c: &mut Criterion, label: &str)
|
||||
where
|
||||
Module<B>: VecZnxNegate + ModuleNew<B>,
|
||||
{
|
||||
let group_name: String = format!("vec_znx_negate::{}", label);
|
||||
let group_name: String = format!("vec_znx_negate::{label}");
|
||||
|
||||
let mut group = c.benchmark_group(group_name);
|
||||
|
||||
@@ -93,7 +93,7 @@ pub fn bench_vec_znx_negate_inplace<B: Backend>(c: &mut Criterion, label: &str)
|
||||
where
|
||||
Module<B>: VecZnxNegateInplace + ModuleNew<B>,
|
||||
{
|
||||
let group_name: String = format!("vec_znx_negate_inplace::{}", label);
|
||||
let group_name: String = format!("vec_znx_negate_inplace::{label}");
|
||||
|
||||
let mut group = c.benchmark_group(group_name);
|
||||
|
||||
|
||||
@@ -6,71 +6,204 @@ use crate::{
|
||||
api::{ModuleNew, ScratchOwnedAlloc, ScratchOwnedBorrow, VecZnxNormalize, VecZnxNormalizeInplace, VecZnxNormalizeTmpBytes},
|
||||
layouts::{Backend, FillUniform, Module, ScratchOwned, VecZnx, VecZnxToMut, VecZnxToRef, ZnxInfos, ZnxView, ZnxViewMut},
|
||||
reference::znx::{
|
||||
ZnxNormalizeFinalStep, ZnxNormalizeFinalStepInplace, ZnxNormalizeFirstStep, ZnxNormalizeFirstStepCarryOnly,
|
||||
ZnxNormalizeFirstStepInplace, ZnxNormalizeMiddleStep, ZnxNormalizeMiddleStepCarryOnly, ZnxNormalizeMiddleStepInplace,
|
||||
ZnxZero,
|
||||
ZnxAddInplace, ZnxCopy, ZnxExtractDigitAddMul, ZnxMulPowerOfTwoInplace, ZnxNormalizeDigit, ZnxNormalizeFinalStep,
|
||||
ZnxNormalizeFinalStepInplace, ZnxNormalizeFirstStep, ZnxNormalizeFirstStepCarryOnly, ZnxNormalizeFirstStepInplace,
|
||||
ZnxNormalizeMiddleStep, ZnxNormalizeMiddleStepCarryOnly, ZnxNormalizeMiddleStepInplace, ZnxZero,
|
||||
},
|
||||
source::Source,
|
||||
};
|
||||
|
||||
pub fn vec_znx_normalize_tmp_bytes(n: usize) -> usize {
|
||||
n * size_of::<i64>()
|
||||
2 * n * size_of::<i64>()
|
||||
}
|
||||
|
||||
pub fn vec_znx_normalize<R, A, ZNXARI>(basek: usize, res: &mut R, res_col: usize, a: &A, a_col: usize, carry: &mut [i64])
|
||||
where
|
||||
pub fn vec_znx_normalize<R, A, ZNXARI>(
|
||||
res_base2k: usize,
|
||||
res: &mut R,
|
||||
res_col: usize,
|
||||
a_base2k: usize,
|
||||
a: &A,
|
||||
a_col: usize,
|
||||
carry: &mut [i64],
|
||||
) where
|
||||
R: VecZnxToMut,
|
||||
A: VecZnxToRef,
|
||||
ZNXARI: ZnxZero
|
||||
+ ZnxCopy
|
||||
+ ZnxAddInplace
|
||||
+ ZnxMulPowerOfTwoInplace
|
||||
+ ZnxNormalizeFirstStepCarryOnly
|
||||
+ ZnxNormalizeMiddleStepCarryOnly
|
||||
+ ZnxNormalizeMiddleStep
|
||||
+ ZnxNormalizeFinalStep
|
||||
+ ZnxNormalizeFirstStep,
|
||||
+ ZnxNormalizeFirstStep
|
||||
+ ZnxExtractDigitAddMul
|
||||
+ ZnxNormalizeDigit,
|
||||
{
|
||||
let mut res: VecZnx<&mut [u8]> = res.to_mut();
|
||||
let a: VecZnx<&[u8]> = a.to_ref();
|
||||
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
assert!(carry.len() >= res.n());
|
||||
assert!(carry.len() >= 2 * res.n());
|
||||
assert_eq!(res.n(), a.n());
|
||||
}
|
||||
|
||||
let n: usize = res.n();
|
||||
let res_size: usize = res.size();
|
||||
let a_size = a.size();
|
||||
let a_size: usize = a.size();
|
||||
|
||||
if a_size > res_size {
|
||||
for j in (res_size..a_size).rev() {
|
||||
if j == a_size - 1 {
|
||||
ZNXARI::znx_normalize_first_step_carry_only(basek, 0, a.at(a_col, j), carry);
|
||||
} else {
|
||||
ZNXARI::znx_normalize_middle_step_carry_only(basek, 0, a.at(a_col, j), carry);
|
||||
if res_base2k == a_base2k {
|
||||
if a_size > res_size {
|
||||
for j in (res_size..a_size).rev() {
|
||||
if j == a_size - 1 {
|
||||
ZNXARI::znx_normalize_first_step_carry_only(res_base2k, 0, a.at(a_col, j), carry);
|
||||
} else {
|
||||
ZNXARI::znx_normalize_middle_step_carry_only(res_base2k, 0, a.at(a_col, j), carry);
|
||||
}
|
||||
}
|
||||
|
||||
for j in (1..res_size).rev() {
|
||||
ZNXARI::znx_normalize_middle_step(res_base2k, 0, res.at_mut(res_col, j), a.at(a_col, j), carry);
|
||||
}
|
||||
|
||||
ZNXARI::znx_normalize_final_step(res_base2k, 0, res.at_mut(res_col, 0), a.at(a_col, 0), carry);
|
||||
} else {
|
||||
for j in (0..a_size).rev() {
|
||||
if j == a_size - 1 {
|
||||
ZNXARI::znx_normalize_first_step(res_base2k, 0, res.at_mut(res_col, j), a.at(a_col, j), carry);
|
||||
} else if j == 0 {
|
||||
ZNXARI::znx_normalize_final_step(res_base2k, 0, res.at_mut(res_col, j), a.at(a_col, j), carry);
|
||||
} else {
|
||||
ZNXARI::znx_normalize_middle_step(res_base2k, 0, res.at_mut(res_col, j), a.at(a_col, j), carry);
|
||||
}
|
||||
}
|
||||
|
||||
for j in a_size..res_size {
|
||||
ZNXARI::znx_zero(res.at_mut(res_col, j));
|
||||
}
|
||||
}
|
||||
|
||||
for j in (1..res_size).rev() {
|
||||
ZNXARI::znx_normalize_middle_step(basek, 0, res.at_mut(res_col, j), a.at(a_col, j), carry);
|
||||
}
|
||||
|
||||
ZNXARI::znx_normalize_final_step(basek, 0, res.at_mut(res_col, 0), a.at(a_col, 0), carry);
|
||||
} else {
|
||||
for j in (0..a_size).rev() {
|
||||
let (a_norm, carry) = carry.split_at_mut(n);
|
||||
|
||||
// Relevant limbs of res
|
||||
let res_min_size: usize = (a_size * a_base2k).div_ceil(res_base2k).min(res_size);
|
||||
|
||||
// Relevant limbs of a
|
||||
let a_min_size: usize = (res_size * res_base2k).div_ceil(a_base2k).min(a_size);
|
||||
|
||||
// Get carry for limbs of a that have higher precision than res
|
||||
for j in (a_min_size..a_size).rev() {
|
||||
if j == a_size - 1 {
|
||||
ZNXARI::znx_normalize_first_step(basek, 0, res.at_mut(res_col, j), a.at(a_col, j), carry);
|
||||
} else if j == 0 {
|
||||
ZNXARI::znx_normalize_final_step(basek, 0, res.at_mut(res_col, j), a.at(a_col, j), carry);
|
||||
ZNXARI::znx_normalize_first_step_carry_only(res_base2k, 0, a.at(a_col, j), carry);
|
||||
} else {
|
||||
ZNXARI::znx_normalize_middle_step(basek, 0, res.at_mut(res_col, j), a.at(a_col, j), carry);
|
||||
ZNXARI::znx_normalize_middle_step_carry_only(res_base2k, 0, a.at(a_col, j), carry);
|
||||
}
|
||||
}
|
||||
|
||||
for j in a_size..res_size {
|
||||
if a_min_size == a_size {
|
||||
ZNXARI::znx_zero(carry);
|
||||
}
|
||||
|
||||
// Maximum relevant precision of a
|
||||
let a_prec: usize = a_min_size * a_base2k;
|
||||
|
||||
// Maximum relevant precision of res
|
||||
let res_prec: usize = res_min_size * res_base2k;
|
||||
|
||||
// Res limb index
|
||||
let mut res_idx: usize = res_min_size - 1;
|
||||
|
||||
// Trackers: wow much of res is left to be populated
|
||||
// for the current limb.
|
||||
let mut res_left: usize = res_base2k;
|
||||
|
||||
for j in (0..a_min_size).rev() {
|
||||
// Trackers: wow much of a_norm is left to
|
||||
// be flushed on res.
|
||||
let mut a_left: usize = a_base2k;
|
||||
|
||||
// Normalizes the j-th limb of a and store the results into a_norm.
|
||||
// This step is required to avoid overflow in the next step,
|
||||
// which assumes that |a| is bounded by 2^{a_base2k -1}.
|
||||
if j != 0 {
|
||||
ZNXARI::znx_normalize_middle_step(a_base2k, 0, a_norm, a.at(a_col, j), carry);
|
||||
} else {
|
||||
ZNXARI::znx_normalize_final_step(a_base2k, 0, a_norm, a.at(a_col, j), carry);
|
||||
}
|
||||
|
||||
// In the first iteration we need to match the precision of the input/output.
|
||||
// If a_min_size * a_base2k > res_min_size * res_base2k
|
||||
// then divround a_norm by the difference of precision and
|
||||
// acts like if a_norm has already been partially consummed.
|
||||
// Else acts like if res has been already populated
|
||||
// by the difference.
|
||||
if j == a_min_size - 1 {
|
||||
if a_prec > res_prec {
|
||||
ZNXARI::znx_mul_power_of_two_inplace(res_prec as i64 - a_prec as i64, a_norm);
|
||||
a_left -= a_prec - res_prec;
|
||||
} else if res_prec > a_prec {
|
||||
res_left -= res_prec - a_prec;
|
||||
}
|
||||
}
|
||||
|
||||
// Flushes a into res
|
||||
loop {
|
||||
// Selects the maximum amount of a that can be flushed
|
||||
let a_take: usize = a_base2k.min(a_left).min(res_left);
|
||||
|
||||
// Output limb
|
||||
let res_slice: &mut [i64] = res.at_mut(res_col, res_idx);
|
||||
|
||||
// Scaling of the value to flush
|
||||
let lsh: usize = res_base2k - res_left;
|
||||
|
||||
// Extract the bits to flush on the output and updates
|
||||
// a_norm accordingly.
|
||||
ZNXARI::znx_extract_digit_addmul(a_take, lsh, res_slice, a_norm);
|
||||
|
||||
// Updates the trackers
|
||||
a_left -= a_take;
|
||||
res_left -= a_take;
|
||||
|
||||
// If the current limb of res is full,
|
||||
// then normalizes this limb and adds
|
||||
// the carry on a_norm.
|
||||
if res_left == 0 {
|
||||
// Updates tracker
|
||||
res_left += res_base2k;
|
||||
|
||||
// Normalizes res and propagates the carry on a.
|
||||
ZNXARI::znx_normalize_digit(res_base2k, res_slice, a_norm);
|
||||
|
||||
// If we reached the last limb of res breaks,
|
||||
// but we might rerun the above loop if the
|
||||
// base2k of a is much smaller than the base2k
|
||||
// of res.
|
||||
if res_idx == 0 {
|
||||
ZNXARI::znx_add_inplace(carry, a_norm);
|
||||
break;
|
||||
}
|
||||
|
||||
// Else updates the limb index of res.
|
||||
res_idx -= 1
|
||||
}
|
||||
|
||||
// If a_norm is exhausted, breaks the loop.
|
||||
if a_left == 0 {
|
||||
ZNXARI::znx_add_inplace(carry, a_norm);
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for j in res_min_size..res_size {
|
||||
ZNXARI::znx_zero(res.at_mut(res_col, j));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub fn vec_znx_normalize_inplace<R: VecZnxToMut, ZNXARI>(basek: usize, res: &mut R, res_col: usize, carry: &mut [i64])
|
||||
pub fn vec_znx_normalize_inplace<R: VecZnxToMut, ZNXARI>(base2k: usize, res: &mut R, res_col: usize, carry: &mut [i64])
|
||||
where
|
||||
ZNXARI: ZnxNormalizeFirstStepInplace + ZnxNormalizeMiddleStepInplace + ZnxNormalizeFinalStepInplace,
|
||||
{
|
||||
@@ -85,11 +218,11 @@ where
|
||||
|
||||
for j in (0..res_size).rev() {
|
||||
if j == res_size - 1 {
|
||||
ZNXARI::znx_normalize_first_step_inplace(basek, 0, res.at_mut(res_col, j), carry);
|
||||
ZNXARI::znx_normalize_first_step_inplace(base2k, 0, res.at_mut(res_col, j), carry);
|
||||
} else if j == 0 {
|
||||
ZNXARI::znx_normalize_final_step_inplace(basek, 0, res.at_mut(res_col, j), carry);
|
||||
ZNXARI::znx_normalize_final_step_inplace(base2k, 0, res.at_mut(res_col, j), carry);
|
||||
} else {
|
||||
ZNXARI::znx_normalize_middle_step_inplace(basek, 0, res.at_mut(res_col, j), carry);
|
||||
ZNXARI::znx_normalize_middle_step_inplace(base2k, 0, res.at_mut(res_col, j), carry);
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -99,7 +232,7 @@ where
|
||||
Module<B>: VecZnxNormalize<B> + ModuleNew<B> + VecZnxNormalizeTmpBytes,
|
||||
ScratchOwned<B>: ScratchOwnedAlloc<B> + ScratchOwnedBorrow<B>,
|
||||
{
|
||||
let group_name: String = format!("vec_znx_normalize::{}", label);
|
||||
let group_name: String = format!("vec_znx_normalize::{label}");
|
||||
|
||||
let mut group = c.benchmark_group(group_name);
|
||||
|
||||
@@ -114,7 +247,7 @@ where
|
||||
|
||||
let module: Module<B> = Module::<B>::new(n as u64);
|
||||
|
||||
let basek: usize = 50;
|
||||
let base2k: usize = 50;
|
||||
|
||||
let mut source: Source = Source::new([0u8; 32]);
|
||||
|
||||
@@ -129,7 +262,7 @@ where
|
||||
|
||||
move || {
|
||||
for i in 0..cols {
|
||||
module.vec_znx_normalize(basek, &mut res, i, &a, i, scratch.borrow());
|
||||
module.vec_znx_normalize(base2k, &mut res, i, base2k, &a, i, scratch.borrow());
|
||||
}
|
||||
black_box(());
|
||||
}
|
||||
@@ -149,7 +282,7 @@ where
|
||||
Module<B>: VecZnxNormalizeInplace<B> + ModuleNew<B> + VecZnxNormalizeTmpBytes,
|
||||
ScratchOwned<B>: ScratchOwnedAlloc<B> + ScratchOwnedBorrow<B>,
|
||||
{
|
||||
let group_name: String = format!("vec_znx_normalize_inplace::{}", label);
|
||||
let group_name: String = format!("vec_znx_normalize_inplace::{label}");
|
||||
|
||||
let mut group = c.benchmark_group(group_name);
|
||||
|
||||
@@ -164,7 +297,7 @@ where
|
||||
|
||||
let module: Module<B> = Module::<B>::new(n as u64);
|
||||
|
||||
let basek: usize = 50;
|
||||
let base2k: usize = 50;
|
||||
|
||||
let mut source: Source = Source::new([0u8; 32]);
|
||||
|
||||
@@ -177,7 +310,7 @@ where
|
||||
|
||||
move || {
|
||||
for i in 0..cols {
|
||||
module.vec_znx_normalize_inplace(basek, &mut a, i, scratch.borrow());
|
||||
module.vec_znx_normalize_inplace(base2k, &mut a, i, scratch.borrow());
|
||||
}
|
||||
black_box(());
|
||||
}
|
||||
@@ -191,3 +324,83 @@ where
|
||||
|
||||
group.finish();
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_vec_znx_normalize_conv() {
|
||||
let n: usize = 8;
|
||||
|
||||
let mut carry: Vec<i64> = vec![0i64; 2 * n];
|
||||
|
||||
use crate::reference::znx::ZnxRef;
|
||||
use rug::ops::SubAssignRound;
|
||||
use rug::{Float, float::Round};
|
||||
|
||||
let mut source: Source = Source::new([1u8; 32]);
|
||||
|
||||
let prec: usize = 128;
|
||||
|
||||
let mut data: Vec<i128> = vec![0i128; n];
|
||||
|
||||
data.iter_mut().for_each(|x| *x = source.next_i128());
|
||||
|
||||
for start_base2k in 1..50 {
|
||||
for end_base2k in 1..50 {
|
||||
let end_size: usize = prec.div_ceil(end_base2k);
|
||||
|
||||
let mut want: VecZnx<Vec<u8>> = VecZnx::alloc(n, 1, end_size);
|
||||
want.encode_vec_i128(end_base2k, 0, prec, &data);
|
||||
vec_znx_normalize_inplace::<_, ZnxRef>(end_base2k, &mut want, 0, &mut carry);
|
||||
|
||||
// Creates a temporary poly where encoding is in start_base2k
|
||||
let mut tmp: VecZnx<Vec<u8>> = VecZnx::alloc(n, 1, prec.div_ceil(start_base2k));
|
||||
tmp.encode_vec_i128(start_base2k, 0, prec, &data);
|
||||
|
||||
vec_znx_normalize_inplace::<_, ZnxRef>(start_base2k, &mut tmp, 0, &mut carry);
|
||||
|
||||
let mut data_tmp: Vec<Float> = (0..n).map(|_| Float::with_val(prec as u32, 0)).collect();
|
||||
tmp.decode_vec_float(start_base2k, 0, &mut data_tmp);
|
||||
|
||||
let mut have: VecZnx<Vec<u8>> = VecZnx::alloc(n, 1, end_size);
|
||||
vec_znx_normalize::<_, _, ZnxRef>(end_base2k, &mut have, 0, start_base2k, &tmp, 0, &mut carry);
|
||||
|
||||
let out_prec: u32 = (end_size * end_base2k) as u32;
|
||||
|
||||
let mut data_want: Vec<Float> = (0..n)
|
||||
.map(|_| Float::with_val(out_prec as u32, 0))
|
||||
.collect();
|
||||
let mut data_res: Vec<Float> = (0..n)
|
||||
.map(|_| Float::with_val(out_prec as u32, 0))
|
||||
.collect();
|
||||
|
||||
have.decode_vec_float(end_base2k, 0, &mut data_want);
|
||||
want.decode_vec_float(end_base2k, 0, &mut data_res);
|
||||
|
||||
for i in 0..n {
|
||||
let mut err: Float = data_want[i].clone();
|
||||
err.sub_assign_round(&data_res[i], Round::Nearest);
|
||||
err = err.abs();
|
||||
|
||||
// println!(
|
||||
// "want: {} have: {} tmp: {} (want-have): {}",
|
||||
// data_want[i].to_f64(),
|
||||
// data_res[i].to_f64(),
|
||||
// data_tmp[i].to_f64(),
|
||||
// err.to_f64()
|
||||
// );
|
||||
|
||||
let err_log2: f64 = err
|
||||
.clone()
|
||||
.max(&Float::with_val(prec as u32, 1e-60))
|
||||
.log2()
|
||||
.to_f64();
|
||||
|
||||
assert!(
|
||||
err_log2 <= -(out_prec as f64) + 1.,
|
||||
"{} {}",
|
||||
err_log2,
|
||||
-(out_prec as f64) + 1.
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -61,7 +61,7 @@ pub fn bench_vec_znx_rotate<B: Backend>(c: &mut Criterion, label: &str)
|
||||
where
|
||||
Module<B>: VecZnxRotate + ModuleNew<B>,
|
||||
{
|
||||
let group_name: String = format!("vec_znx_rotate::{}", label);
|
||||
let group_name: String = format!("vec_znx_rotate::{label}");
|
||||
|
||||
let mut group = c.benchmark_group(group_name);
|
||||
|
||||
@@ -106,7 +106,7 @@ where
|
||||
Module<B>: VecZnxRotateInplace<B> + VecZnxRotateInplaceTmpBytes + ModuleNew<B>,
|
||||
ScratchOwned<B>: ScratchOwnedAlloc<B> + ScratchOwnedBorrow<B>,
|
||||
{
|
||||
let group_name: String = format!("vec_znx_rotate_inplace::{}", label);
|
||||
let group_name: String = format!("vec_znx_rotate_inplace::{label}");
|
||||
|
||||
let mut group = c.benchmark_group(group_name);
|
||||
|
||||
|
||||
@@ -4,18 +4,18 @@ use crate::{
|
||||
source::Source,
|
||||
};
|
||||
|
||||
pub fn vec_znx_fill_uniform_ref<R>(basek: usize, res: &mut R, res_col: usize, source: &mut Source)
|
||||
pub fn vec_znx_fill_uniform_ref<R>(base2k: usize, res: &mut R, res_col: usize, source: &mut Source)
|
||||
where
|
||||
R: VecZnxToMut,
|
||||
{
|
||||
let mut res: VecZnx<&mut [u8]> = res.to_mut();
|
||||
for j in 0..res.size() {
|
||||
znx_fill_uniform_ref(basek, res.at_mut(res_col, j), source)
|
||||
znx_fill_uniform_ref(base2k, res.at_mut(res_col, j), source)
|
||||
}
|
||||
}
|
||||
|
||||
pub fn vec_znx_fill_normal_ref<R>(
|
||||
basek: usize,
|
||||
base2k: usize,
|
||||
res: &mut R,
|
||||
res_col: usize,
|
||||
k: usize,
|
||||
@@ -32,8 +32,8 @@ pub fn vec_znx_fill_normal_ref<R>(
|
||||
(bound.log2().ceil() as i64)
|
||||
);
|
||||
|
||||
let limb: usize = k.div_ceil(basek) - 1;
|
||||
let scale: f64 = (1 << ((limb + 1) * basek - k)) as f64;
|
||||
let limb: usize = k.div_ceil(base2k) - 1;
|
||||
let scale: f64 = (1 << ((limb + 1) * base2k - k)) as f64;
|
||||
znx_fill_normal_f64_ref(
|
||||
res.at_mut(res_col, limb),
|
||||
sigma * scale,
|
||||
@@ -42,8 +42,15 @@ pub fn vec_znx_fill_normal_ref<R>(
|
||||
)
|
||||
}
|
||||
|
||||
pub fn vec_znx_add_normal_ref<R>(basek: usize, res: &mut R, res_col: usize, k: usize, sigma: f64, bound: f64, source: &mut Source)
|
||||
where
|
||||
pub fn vec_znx_add_normal_ref<R>(
|
||||
base2k: usize,
|
||||
res: &mut R,
|
||||
res_col: usize,
|
||||
k: usize,
|
||||
sigma: f64,
|
||||
bound: f64,
|
||||
source: &mut Source,
|
||||
) where
|
||||
R: VecZnxToMut,
|
||||
{
|
||||
let mut res: VecZnx<&mut [u8]> = res.to_mut();
|
||||
@@ -53,8 +60,8 @@ where
|
||||
(bound.log2().ceil() as i64)
|
||||
);
|
||||
|
||||
let limb: usize = k.div_ceil(basek) - 1;
|
||||
let scale: f64 = (1 << ((limb + 1) * basek - k)) as f64;
|
||||
let limb: usize = k.div_ceil(base2k) - 1;
|
||||
let scale: f64 = (1 << ((limb + 1) * base2k - k)) as f64;
|
||||
znx_add_normal_f64_ref(
|
||||
res.at_mut(res_col, limb),
|
||||
sigma * scale,
|
||||
|
||||
@@ -20,7 +20,7 @@ pub fn vec_znx_lsh_tmp_bytes(n: usize) -> usize {
|
||||
n * size_of::<i64>()
|
||||
}
|
||||
|
||||
pub fn vec_znx_lsh_inplace<R, ZNXARI>(basek: usize, k: usize, res: &mut R, res_col: usize, carry: &mut [i64])
|
||||
pub fn vec_znx_lsh_inplace<R, ZNXARI>(base2k: usize, k: usize, res: &mut R, res_col: usize, carry: &mut [i64])
|
||||
where
|
||||
R: VecZnxToMut,
|
||||
ZNXARI: ZnxZero
|
||||
@@ -35,8 +35,8 @@ where
|
||||
let n: usize = res.n();
|
||||
let cols: usize = res.cols();
|
||||
let size: usize = res.size();
|
||||
let steps: usize = k / basek;
|
||||
let k_rem: usize = k % basek;
|
||||
let steps: usize = k / base2k;
|
||||
let k_rem: usize = k % base2k;
|
||||
|
||||
if steps >= size {
|
||||
for j in 0..size {
|
||||
@@ -45,7 +45,7 @@ where
|
||||
return;
|
||||
}
|
||||
|
||||
// Inplace shift of limbs by a k/basek
|
||||
// Inplace shift of limbs by a k/base2k
|
||||
if steps > 0 {
|
||||
let start: usize = n * res_col;
|
||||
let end: usize = start + n;
|
||||
@@ -65,21 +65,21 @@ where
|
||||
}
|
||||
}
|
||||
|
||||
// Inplace normalization with left shift of k % basek
|
||||
if !k.is_multiple_of(basek) {
|
||||
// Inplace normalization with left shift of k % base2k
|
||||
if !k.is_multiple_of(base2k) {
|
||||
for j in (0..size - steps).rev() {
|
||||
if j == size - steps - 1 {
|
||||
ZNXARI::znx_normalize_first_step_inplace(basek, k_rem, res.at_mut(res_col, j), carry);
|
||||
ZNXARI::znx_normalize_first_step_inplace(base2k, k_rem, res.at_mut(res_col, j), carry);
|
||||
} else if j == 0 {
|
||||
ZNXARI::znx_normalize_final_step_inplace(basek, k_rem, res.at_mut(res_col, j), carry);
|
||||
ZNXARI::znx_normalize_final_step_inplace(base2k, k_rem, res.at_mut(res_col, j), carry);
|
||||
} else {
|
||||
ZNXARI::znx_normalize_middle_step_inplace(basek, k_rem, res.at_mut(res_col, j), carry);
|
||||
ZNXARI::znx_normalize_middle_step_inplace(base2k, k_rem, res.at_mut(res_col, j), carry);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub fn vec_znx_lsh<R, A, ZNXARI>(basek: usize, k: usize, res: &mut R, res_col: usize, a: &A, a_col: usize, carry: &mut [i64])
|
||||
pub fn vec_znx_lsh<R, A, ZNXARI>(base2k: usize, k: usize, res: &mut R, res_col: usize, a: &A, a_col: usize, carry: &mut [i64])
|
||||
where
|
||||
R: VecZnxToMut,
|
||||
A: VecZnxToRef,
|
||||
@@ -90,8 +90,8 @@ where
|
||||
|
||||
let res_size: usize = res.size();
|
||||
let a_size = a.size();
|
||||
let steps: usize = k / basek;
|
||||
let k_rem: usize = k % basek;
|
||||
let steps: usize = k / base2k;
|
||||
let k_rem: usize = k % base2k;
|
||||
|
||||
if steps >= res_size.min(a_size) {
|
||||
for j in 0..res_size {
|
||||
@@ -103,12 +103,12 @@ where
|
||||
let min_size: usize = a_size.min(res_size) - steps;
|
||||
|
||||
// Simply a left shifted normalization of limbs
|
||||
// by k/basek and intra-limb by basek - k%basek
|
||||
if !k.is_multiple_of(basek) {
|
||||
// by k/base2k and intra-limb by base2k - k%base2k
|
||||
if !k.is_multiple_of(base2k) {
|
||||
for j in (0..min_size).rev() {
|
||||
if j == min_size - 1 {
|
||||
ZNXARI::znx_normalize_first_step(
|
||||
basek,
|
||||
base2k,
|
||||
k_rem,
|
||||
res.at_mut(res_col, j),
|
||||
a.at(a_col, j + steps),
|
||||
@@ -116,7 +116,7 @@ where
|
||||
);
|
||||
} else if j == 0 {
|
||||
ZNXARI::znx_normalize_final_step(
|
||||
basek,
|
||||
base2k,
|
||||
k_rem,
|
||||
res.at_mut(res_col, j),
|
||||
a.at(a_col, j + steps),
|
||||
@@ -124,7 +124,7 @@ where
|
||||
);
|
||||
} else {
|
||||
ZNXARI::znx_normalize_middle_step(
|
||||
basek,
|
||||
base2k,
|
||||
k_rem,
|
||||
res.at_mut(res_col, j),
|
||||
a.at(a_col, j + steps),
|
||||
@@ -133,7 +133,7 @@ where
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// If k % basek = 0, then this is simply a copy.
|
||||
// If k % base2k = 0, then this is simply a copy.
|
||||
for j in (0..min_size).rev() {
|
||||
ZNXARI::znx_copy(res.at_mut(res_col, j), a.at(a_col, j + steps));
|
||||
}
|
||||
@@ -149,7 +149,7 @@ pub fn vec_znx_rsh_tmp_bytes(n: usize) -> usize {
|
||||
n * size_of::<i64>()
|
||||
}
|
||||
|
||||
pub fn vec_znx_rsh_inplace<R, ZNXARI>(basek: usize, k: usize, res: &mut R, res_col: usize, carry: &mut [i64])
|
||||
pub fn vec_znx_rsh_inplace<R, ZNXARI>(base2k: usize, k: usize, res: &mut R, res_col: usize, carry: &mut [i64])
|
||||
where
|
||||
R: VecZnxToMut,
|
||||
ZNXARI: ZnxZero
|
||||
@@ -166,8 +166,8 @@ where
|
||||
let cols: usize = res.cols();
|
||||
let size: usize = res.size();
|
||||
|
||||
let mut steps: usize = k / basek;
|
||||
let k_rem: usize = k % basek;
|
||||
let mut steps: usize = k / base2k;
|
||||
let k_rem: usize = k % base2k;
|
||||
|
||||
if k == 0 {
|
||||
return;
|
||||
@@ -184,8 +184,8 @@ where
|
||||
let end: usize = start + n;
|
||||
let slice_size: usize = n * cols;
|
||||
|
||||
if !k.is_multiple_of(basek) {
|
||||
// We rsh by an additional basek and then lsh by basek-k
|
||||
if !k.is_multiple_of(base2k) {
|
||||
// We rsh by an additional base2k and then lsh by base2k-k
|
||||
// Allows to re-use efficient normalization code, avoids
|
||||
// avoids overflows & produce output that is normalized
|
||||
steps += 1;
|
||||
@@ -194,9 +194,9 @@ where
|
||||
// but the carry still need to be computed.
|
||||
(size - steps..size).rev().for_each(|j| {
|
||||
if j == size - 1 {
|
||||
ZNXARI::znx_normalize_first_step_carry_only(basek, basek - k_rem, res.at(res_col, j), carry);
|
||||
ZNXARI::znx_normalize_first_step_carry_only(base2k, base2k - k_rem, res.at(res_col, j), carry);
|
||||
} else {
|
||||
ZNXARI::znx_normalize_middle_step_carry_only(basek, basek - k_rem, res.at(res_col, j), carry);
|
||||
ZNXARI::znx_normalize_middle_step_carry_only(base2k, base2k - k_rem, res.at(res_col, j), carry);
|
||||
}
|
||||
});
|
||||
|
||||
@@ -206,20 +206,20 @@ where
|
||||
let (lhs, rhs) = res_raw.split_at_mut(slice_size * j);
|
||||
let rhs_slice: &mut [i64] = &mut rhs[start..end];
|
||||
let lhs_slice: &[i64] = &lhs[(j - steps) * slice_size + start..(j - steps) * slice_size + end];
|
||||
ZNXARI::znx_normalize_middle_step(basek, basek - k_rem, rhs_slice, lhs_slice, carry);
|
||||
ZNXARI::znx_normalize_middle_step(base2k, base2k - k_rem, rhs_slice, lhs_slice, carry);
|
||||
});
|
||||
|
||||
// Propagates carry on the rest of the limbs of res
|
||||
for j in (0..steps).rev() {
|
||||
ZNXARI::znx_zero(res.at_mut(res_col, j));
|
||||
if j == 0 {
|
||||
ZNXARI::znx_normalize_final_step_inplace(basek, basek - k_rem, res.at_mut(res_col, j), carry);
|
||||
ZNXARI::znx_normalize_final_step_inplace(base2k, base2k - k_rem, res.at_mut(res_col, j), carry);
|
||||
} else {
|
||||
ZNXARI::znx_normalize_middle_step_inplace(basek, basek - k_rem, res.at_mut(res_col, j), carry);
|
||||
ZNXARI::znx_normalize_middle_step_inplace(base2k, base2k - k_rem, res.at_mut(res_col, j), carry);
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// Shift by multiples of basek
|
||||
// Shift by multiples of base2k
|
||||
let res_raw: &mut [i64] = res.raw_mut();
|
||||
(steps..size).rev().for_each(|j| {
|
||||
let (lhs, rhs) = res_raw.split_at_mut(slice_size * j);
|
||||
@@ -236,7 +236,7 @@ where
|
||||
}
|
||||
}
|
||||
|
||||
pub fn vec_znx_rsh<R, A, ZNXARI>(basek: usize, k: usize, res: &mut R, res_col: usize, a: &A, a_col: usize, carry: &mut [i64])
|
||||
pub fn vec_znx_rsh<R, A, ZNXARI>(base2k: usize, k: usize, res: &mut R, res_col: usize, a: &A, a_col: usize, carry: &mut [i64])
|
||||
where
|
||||
R: VecZnxToMut,
|
||||
A: VecZnxToRef,
|
||||
@@ -256,8 +256,8 @@ where
|
||||
let res_size: usize = res.size();
|
||||
let a_size: usize = a.size();
|
||||
|
||||
let mut steps: usize = k / basek;
|
||||
let k_rem: usize = k % basek;
|
||||
let mut steps: usize = k / base2k;
|
||||
let k_rem: usize = k % base2k;
|
||||
|
||||
if k == 0 {
|
||||
vec_znx_copy::<_, _, ZNXARI>(&mut res, res_col, &a, a_col);
|
||||
@@ -271,8 +271,8 @@ where
|
||||
return;
|
||||
}
|
||||
|
||||
if !k.is_multiple_of(basek) {
|
||||
// We rsh by an additional basek and then lsh by basek-k
|
||||
if !k.is_multiple_of(base2k) {
|
||||
// We rsh by an additional base2k and then lsh by base2k-k
|
||||
// Allows to re-use efficient normalization code, avoids
|
||||
// avoids overflows & produce output that is normalized
|
||||
steps += 1;
|
||||
@@ -281,9 +281,9 @@ where
|
||||
// but the carry still need to be computed.
|
||||
for j in (res_size..a_size + steps).rev() {
|
||||
if j == a_size + steps - 1 {
|
||||
ZNXARI::znx_normalize_first_step_carry_only(basek, basek - k_rem, a.at(a_col, j - steps), carry);
|
||||
ZNXARI::znx_normalize_first_step_carry_only(base2k, base2k - k_rem, a.at(a_col, j - steps), carry);
|
||||
} else {
|
||||
ZNXARI::znx_normalize_middle_step_carry_only(basek, basek - k_rem, a.at(a_col, j - steps), carry);
|
||||
ZNXARI::znx_normalize_middle_step_carry_only(base2k, base2k - k_rem, a.at(a_col, j - steps), carry);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -300,16 +300,16 @@ where
|
||||
// Case if no limb of a was previously discarded
|
||||
if res_size.saturating_sub(steps) >= a_size && j == min_size - 1 {
|
||||
ZNXARI::znx_normalize_first_step(
|
||||
basek,
|
||||
basek - k_rem,
|
||||
base2k,
|
||||
base2k - k_rem,
|
||||
res.at_mut(res_col, j),
|
||||
a.at(a_col, j - steps),
|
||||
carry,
|
||||
);
|
||||
} else {
|
||||
ZNXARI::znx_normalize_middle_step(
|
||||
basek,
|
||||
basek - k_rem,
|
||||
base2k,
|
||||
base2k - k_rem,
|
||||
res.at_mut(res_col, j),
|
||||
a.at(a_col, j - steps),
|
||||
carry,
|
||||
@@ -321,9 +321,9 @@ where
|
||||
for j in (0..steps).rev() {
|
||||
ZNXARI::znx_zero(res.at_mut(res_col, j));
|
||||
if j == 0 {
|
||||
ZNXARI::znx_normalize_final_step_inplace(basek, basek - k_rem, res.at_mut(res_col, j), carry);
|
||||
ZNXARI::znx_normalize_final_step_inplace(base2k, base2k - k_rem, res.at_mut(res_col, j), carry);
|
||||
} else {
|
||||
ZNXARI::znx_normalize_middle_step_inplace(basek, basek - k_rem, res.at_mut(res_col, j), carry);
|
||||
ZNXARI::znx_normalize_middle_step_inplace(base2k, base2k - k_rem, res.at_mut(res_col, j), carry);
|
||||
}
|
||||
}
|
||||
} else {
|
||||
@@ -351,7 +351,7 @@ where
|
||||
Module<B>: ModuleNew<B> + VecZnxLshInplace<B>,
|
||||
ScratchOwned<B>: ScratchOwnedAlloc<B> + ScratchOwnedBorrow<B>,
|
||||
{
|
||||
let group_name: String = format!("vec_znx_lsh_inplace::{}", label);
|
||||
let group_name: String = format!("vec_znx_lsh_inplace::{label}");
|
||||
|
||||
let mut group = c.benchmark_group(group_name);
|
||||
|
||||
@@ -366,7 +366,7 @@ where
|
||||
|
||||
let module: Module<B> = Module::<B>::new(n as u64);
|
||||
|
||||
let basek: usize = 50;
|
||||
let base2k: usize = 50;
|
||||
|
||||
let mut source: Source = Source::new([0u8; 32]);
|
||||
|
||||
@@ -381,7 +381,7 @@ where
|
||||
|
||||
move || {
|
||||
for i in 0..cols {
|
||||
module.vec_znx_lsh_inplace(basek, basek - 1, &mut b, i, scratch.borrow());
|
||||
module.vec_znx_lsh_inplace(base2k, base2k - 1, &mut b, i, scratch.borrow());
|
||||
}
|
||||
black_box(());
|
||||
}
|
||||
@@ -401,7 +401,7 @@ where
|
||||
Module<B>: VecZnxLsh<B> + ModuleNew<B>,
|
||||
ScratchOwned<B>: ScratchOwnedAlloc<B> + ScratchOwnedBorrow<B>,
|
||||
{
|
||||
let group_name: String = format!("vec_znx_lsh::{}", label);
|
||||
let group_name: String = format!("vec_znx_lsh::{label}");
|
||||
|
||||
let mut group = c.benchmark_group(group_name);
|
||||
|
||||
@@ -416,7 +416,7 @@ where
|
||||
|
||||
let module: Module<B> = Module::<B>::new(n as u64);
|
||||
|
||||
let basek: usize = 50;
|
||||
let base2k: usize = 50;
|
||||
|
||||
let mut source: Source = Source::new([0u8; 32]);
|
||||
|
||||
@@ -431,7 +431,7 @@ where
|
||||
|
||||
move || {
|
||||
for i in 0..cols {
|
||||
module.vec_znx_lsh(basek, basek - 1, &mut res, i, &a, i, scratch.borrow());
|
||||
module.vec_znx_lsh(base2k, base2k - 1, &mut res, i, &a, i, scratch.borrow());
|
||||
}
|
||||
black_box(());
|
||||
}
|
||||
@@ -451,7 +451,7 @@ where
|
||||
Module<B>: VecZnxRshInplace<B> + ModuleNew<B>,
|
||||
ScratchOwned<B>: ScratchOwnedAlloc<B> + ScratchOwnedBorrow<B>,
|
||||
{
|
||||
let group_name: String = format!("vec_znx_rsh_inplace::{}", label);
|
||||
let group_name: String = format!("vec_znx_rsh_inplace::{label}");
|
||||
|
||||
let mut group = c.benchmark_group(group_name);
|
||||
|
||||
@@ -466,7 +466,7 @@ where
|
||||
|
||||
let module: Module<B> = Module::<B>::new(n as u64);
|
||||
|
||||
let basek: usize = 50;
|
||||
let base2k: usize = 50;
|
||||
|
||||
let mut source: Source = Source::new([0u8; 32]);
|
||||
|
||||
@@ -481,7 +481,7 @@ where
|
||||
|
||||
move || {
|
||||
for i in 0..cols {
|
||||
module.vec_znx_rsh_inplace(basek, basek - 1, &mut b, i, scratch.borrow());
|
||||
module.vec_znx_rsh_inplace(base2k, base2k - 1, &mut b, i, scratch.borrow());
|
||||
}
|
||||
black_box(());
|
||||
}
|
||||
@@ -501,7 +501,7 @@ where
|
||||
Module<B>: VecZnxRsh<B> + ModuleNew<B>,
|
||||
ScratchOwned<B>: ScratchOwnedAlloc<B> + ScratchOwnedBorrow<B>,
|
||||
{
|
||||
let group_name: String = format!("vec_znx_rsh::{}", label);
|
||||
let group_name: String = format!("vec_znx_rsh::{label}");
|
||||
|
||||
let mut group = c.benchmark_group(group_name);
|
||||
|
||||
@@ -516,7 +516,7 @@ where
|
||||
|
||||
let module: Module<B> = Module::<B>::new(n as u64);
|
||||
|
||||
let basek: usize = 50;
|
||||
let base2k: usize = 50;
|
||||
|
||||
let mut source: Source = Source::new([0u8; 32]);
|
||||
|
||||
@@ -531,7 +531,7 @@ where
|
||||
|
||||
move || {
|
||||
for i in 0..cols {
|
||||
module.vec_znx_rsh(basek, basek - 1, &mut res, i, &a, i, scratch.borrow());
|
||||
module.vec_znx_rsh(base2k, base2k - 1, &mut res, i, &a, i, scratch.borrow());
|
||||
}
|
||||
black_box(());
|
||||
}
|
||||
@@ -553,7 +553,7 @@ mod tests {
|
||||
reference::{
|
||||
vec_znx::{
|
||||
vec_znx_copy, vec_znx_lsh, vec_znx_lsh_inplace, vec_znx_normalize_inplace, vec_znx_rsh, vec_znx_rsh_inplace,
|
||||
vec_znx_sub_ab_inplace,
|
||||
vec_znx_sub_inplace,
|
||||
},
|
||||
znx::ZnxRef,
|
||||
},
|
||||
@@ -574,20 +574,20 @@ mod tests {
|
||||
|
||||
let mut carry: Vec<i64> = vec![0i64; n];
|
||||
|
||||
let basek: usize = 50;
|
||||
let base2k: usize = 50;
|
||||
|
||||
for k in 0..256 {
|
||||
a.fill_uniform(50, &mut source);
|
||||
|
||||
for i in 0..cols {
|
||||
vec_znx_normalize_inplace::<_, ZnxRef>(basek, &mut a, i, &mut carry);
|
||||
vec_znx_normalize_inplace::<_, ZnxRef>(base2k, &mut a, i, &mut carry);
|
||||
vec_znx_copy::<_, _, ZnxRef>(&mut res_ref, i, &a, i);
|
||||
}
|
||||
|
||||
for i in 0..cols {
|
||||
vec_znx_lsh_inplace::<_, ZnxRef>(basek, k, &mut res_ref, i, &mut carry);
|
||||
vec_znx_lsh::<_, _, ZnxRef>(basek, k, &mut res_test, i, &a, i, &mut carry);
|
||||
vec_znx_normalize_inplace::<_, ZnxRef>(basek, &mut res_test, i, &mut carry);
|
||||
vec_znx_lsh_inplace::<_, ZnxRef>(base2k, k, &mut res_ref, i, &mut carry);
|
||||
vec_znx_lsh::<_, _, ZnxRef>(base2k, k, &mut res_test, i, &a, i, &mut carry);
|
||||
vec_znx_normalize_inplace::<_, ZnxRef>(base2k, &mut res_test, i, &mut carry);
|
||||
}
|
||||
|
||||
assert_eq!(res_ref, res_test);
|
||||
@@ -606,7 +606,7 @@ mod tests {
|
||||
|
||||
let mut carry: Vec<i64> = vec![0i64; n];
|
||||
|
||||
let basek: usize = 50;
|
||||
let base2k: usize = 50;
|
||||
|
||||
let mut source: Source = Source::new([0u8; 32]);
|
||||
|
||||
@@ -615,29 +615,29 @@ mod tests {
|
||||
for a_size in [res_size - 1, res_size, res_size + 1] {
|
||||
let mut a: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, a_size);
|
||||
|
||||
for k in 0..res_size * basek {
|
||||
for k in 0..res_size * base2k {
|
||||
a.fill_uniform(50, &mut source);
|
||||
|
||||
for i in 0..cols {
|
||||
vec_znx_normalize_inplace::<_, ZnxRef>(basek, &mut a, i, &mut carry);
|
||||
vec_znx_normalize_inplace::<_, ZnxRef>(base2k, &mut a, i, &mut carry);
|
||||
vec_znx_copy::<_, _, ZnxRef>(&mut res_ref, i, &a, i);
|
||||
}
|
||||
|
||||
res_test.fill_uniform(50, &mut source);
|
||||
|
||||
for j in 0..cols {
|
||||
vec_znx_rsh_inplace::<_, ZnxRef>(basek, k, &mut res_ref, j, &mut carry);
|
||||
vec_znx_rsh::<_, _, ZnxRef>(basek, k, &mut res_test, j, &a, j, &mut carry);
|
||||
vec_znx_rsh_inplace::<_, ZnxRef>(base2k, k, &mut res_ref, j, &mut carry);
|
||||
vec_znx_rsh::<_, _, ZnxRef>(base2k, k, &mut res_test, j, &a, j, &mut carry);
|
||||
}
|
||||
|
||||
for j in 0..cols {
|
||||
vec_znx_lsh_inplace::<_, ZnxRef>(basek, k, &mut res_ref, j, &mut carry);
|
||||
vec_znx_lsh_inplace::<_, ZnxRef>(basek, k, &mut res_test, j, &mut carry);
|
||||
vec_znx_lsh_inplace::<_, ZnxRef>(base2k, k, &mut res_ref, j, &mut carry);
|
||||
vec_znx_lsh_inplace::<_, ZnxRef>(base2k, k, &mut res_test, j, &mut carry);
|
||||
}
|
||||
|
||||
// Case where res has enough to fully store a right shifted without any loss
|
||||
// In this case we can check exact equality.
|
||||
if a_size + k.div_ceil(basek) <= res_size {
|
||||
if a_size + k.div_ceil(base2k) <= res_size {
|
||||
assert_eq!(res_ref, res_test);
|
||||
|
||||
for i in 0..cols {
|
||||
@@ -656,14 +656,14 @@ mod tests {
|
||||
// res.
|
||||
} else {
|
||||
for j in 0..cols {
|
||||
vec_znx_sub_ab_inplace::<_, _, ZnxRef>(&mut res_ref, j, &a, j);
|
||||
vec_znx_sub_ab_inplace::<_, _, ZnxRef>(&mut res_test, j, &a, j);
|
||||
vec_znx_sub_inplace::<_, _, ZnxRef>(&mut res_ref, j, &a, j);
|
||||
vec_znx_sub_inplace::<_, _, ZnxRef>(&mut res_test, j, &a, j);
|
||||
|
||||
vec_znx_normalize_inplace::<_, ZnxRef>(basek, &mut res_ref, j, &mut carry);
|
||||
vec_znx_normalize_inplace::<_, ZnxRef>(basek, &mut res_test, j, &mut carry);
|
||||
vec_znx_normalize_inplace::<_, ZnxRef>(base2k, &mut res_ref, j, &mut carry);
|
||||
vec_znx_normalize_inplace::<_, ZnxRef>(base2k, &mut res_test, j, &mut carry);
|
||||
|
||||
assert!(res_ref.std(basek, j).log2() - (k as f64) <= (k * basek) as f64);
|
||||
assert!(res_test.std(basek, j).log2() - (k as f64) <= (k * basek) as f64);
|
||||
assert!(res_ref.std(base2k, j).log2() - (k as f64) <= (k * base2k) as f64);
|
||||
assert!(res_test.std(base2k, j).log2() - (k as f64) <= (k * base2k) as f64);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -3,10 +3,10 @@ use std::hint::black_box;
|
||||
use criterion::{BenchmarkId, Criterion};
|
||||
|
||||
use crate::{
|
||||
api::{ModuleNew, VecZnxSub, VecZnxSubABInplace, VecZnxSubBAInplace},
|
||||
api::{ModuleNew, VecZnxSub, VecZnxSubInplace, VecZnxSubNegateInplace},
|
||||
layouts::{Backend, FillUniform, Module, VecZnx, VecZnxToMut, VecZnxToRef, ZnxInfos, ZnxView, ZnxViewMut},
|
||||
oep::{ModuleNewImpl, VecZnxSubABInplaceImpl, VecZnxSubBAInplaceImpl, VecZnxSubImpl},
|
||||
reference::znx::{ZnxCopy, ZnxNegate, ZnxNegateInplace, ZnxSub, ZnxSubABInplace, ZnxSubBAInplace, ZnxZero},
|
||||
oep::{ModuleNewImpl, VecZnxSubImpl, VecZnxSubInplaceImpl, VecZnxSubNegateInplaceImpl},
|
||||
reference::znx::{ZnxCopy, ZnxNegate, ZnxNegateInplace, ZnxSub, ZnxSubInplace, ZnxSubNegateInplace, ZnxZero},
|
||||
source::Source,
|
||||
};
|
||||
|
||||
@@ -64,11 +64,11 @@ where
|
||||
}
|
||||
}
|
||||
|
||||
pub fn vec_znx_sub_ab_inplace<R, A, ZNXARI>(res: &mut R, res_col: usize, a: &A, a_col: usize)
|
||||
pub fn vec_znx_sub_inplace<R, A, ZNXARI>(res: &mut R, res_col: usize, a: &A, a_col: usize)
|
||||
where
|
||||
R: VecZnxToMut,
|
||||
A: VecZnxToRef,
|
||||
ZNXARI: ZnxSubABInplace,
|
||||
ZNXARI: ZnxSubInplace,
|
||||
{
|
||||
let a: VecZnx<&[u8]> = a.to_ref();
|
||||
let mut res: VecZnx<&mut [u8]> = res.to_mut();
|
||||
@@ -84,15 +84,15 @@ where
|
||||
let sum_size: usize = a_size.min(res_size);
|
||||
|
||||
for j in 0..sum_size {
|
||||
ZNXARI::znx_sub_ab_inplace(res.at_mut(res_col, j), a.at(a_col, j));
|
||||
ZNXARI::znx_sub_inplace(res.at_mut(res_col, j), a.at(a_col, j));
|
||||
}
|
||||
}
|
||||
|
||||
pub fn vec_znx_sub_ba_inplace<R, A, ZNXARI>(res: &mut R, res_col: usize, a: &A, a_col: usize)
|
||||
pub fn vec_znx_sub_negate_inplace<R, A, ZNXARI>(res: &mut R, res_col: usize, a: &A, a_col: usize)
|
||||
where
|
||||
R: VecZnxToMut,
|
||||
A: VecZnxToRef,
|
||||
ZNXARI: ZnxSubBAInplace + ZnxNegateInplace,
|
||||
ZNXARI: ZnxSubNegateInplace + ZnxNegateInplace,
|
||||
{
|
||||
let a: VecZnx<&[u8]> = a.to_ref();
|
||||
let mut res: VecZnx<&mut [u8]> = res.to_mut();
|
||||
@@ -108,7 +108,7 @@ where
|
||||
let sum_size: usize = a_size.min(res_size);
|
||||
|
||||
for j in 0..sum_size {
|
||||
ZNXARI::znx_sub_ba_inplace(res.at_mut(res_col, j), a.at(a_col, j));
|
||||
ZNXARI::znx_sub_negate_inplace(res.at_mut(res_col, j), a.at(a_col, j));
|
||||
}
|
||||
|
||||
for j in sum_size..res_size {
|
||||
@@ -120,7 +120,7 @@ pub fn bench_vec_znx_sub<B>(c: &mut Criterion, label: &str)
|
||||
where
|
||||
B: Backend + ModuleNewImpl<B> + VecZnxSubImpl<B>,
|
||||
{
|
||||
let group_name: String = format!("vec_znx_sub::{}", label);
|
||||
let group_name: String = format!("vec_znx_sub::{label}");
|
||||
|
||||
let mut group = c.benchmark_group(group_name);
|
||||
|
||||
@@ -161,17 +161,17 @@ where
|
||||
group.finish();
|
||||
}
|
||||
|
||||
pub fn bench_vec_znx_sub_ab_inplace<B>(c: &mut Criterion, label: &str)
|
||||
pub fn bench_vec_znx_sub_inplace<B>(c: &mut Criterion, label: &str)
|
||||
where
|
||||
B: Backend + ModuleNewImpl<B> + VecZnxSubABInplaceImpl<B>,
|
||||
B: Backend + ModuleNewImpl<B> + VecZnxSubInplaceImpl<B>,
|
||||
{
|
||||
let group_name: String = format!("vec_znx_sub_ab_inplace::{}", label);
|
||||
let group_name: String = format!("vec_znx_sub_inplace::{label}");
|
||||
|
||||
let mut group = c.benchmark_group(group_name);
|
||||
|
||||
fn runner<B: Backend>(params: [usize; 3]) -> impl FnMut()
|
||||
where
|
||||
Module<B>: VecZnxSubABInplace + ModuleNew<B>,
|
||||
Module<B>: VecZnxSubInplace + ModuleNew<B>,
|
||||
{
|
||||
let n: usize = 1 << params[0];
|
||||
let cols: usize = params[1];
|
||||
@@ -190,7 +190,7 @@ where
|
||||
|
||||
move || {
|
||||
for i in 0..cols {
|
||||
module.vec_znx_sub_ab_inplace(&mut b, i, &a, i);
|
||||
module.vec_znx_sub_inplace(&mut b, i, &a, i);
|
||||
}
|
||||
black_box(());
|
||||
}
|
||||
@@ -205,17 +205,17 @@ where
|
||||
group.finish();
|
||||
}
|
||||
|
||||
pub fn bench_vec_znx_sub_ba_inplace<B>(c: &mut Criterion, label: &str)
|
||||
pub fn bench_vec_znx_sub_negate_inplace<B>(c: &mut Criterion, label: &str)
|
||||
where
|
||||
B: Backend + ModuleNewImpl<B> + VecZnxSubBAInplaceImpl<B>,
|
||||
B: Backend + ModuleNewImpl<B> + VecZnxSubNegateInplaceImpl<B>,
|
||||
{
|
||||
let group_name: String = format!("vec_znx_sub_ba_inplace::{}", label);
|
||||
let group_name: String = format!("vec_znx_sub_negate_inplace::{label}");
|
||||
|
||||
let mut group = c.benchmark_group(group_name);
|
||||
|
||||
fn runner<B: Backend>(params: [usize; 3]) -> impl FnMut()
|
||||
where
|
||||
Module<B>: VecZnxSubBAInplace + ModuleNew<B>,
|
||||
Module<B>: VecZnxSubNegateInplace + ModuleNew<B>,
|
||||
{
|
||||
let n: usize = 1 << params[0];
|
||||
let cols: usize = params[1];
|
||||
@@ -234,7 +234,7 @@ where
|
||||
|
||||
move || {
|
||||
for i in 0..cols {
|
||||
module.vec_znx_sub_ba_inplace(&mut b, i, &a, i);
|
||||
module.vec_znx_sub_negate_inplace(&mut b, i, &a, i);
|
||||
}
|
||||
black_box(());
|
||||
}
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
use crate::layouts::{ScalarZnxToRef, VecZnxToMut, VecZnxToRef};
|
||||
use crate::{
|
||||
layouts::{ScalarZnx, VecZnx, ZnxInfos, ZnxView, ZnxViewMut},
|
||||
reference::znx::{ZnxSub, ZnxSubABInplace, ZnxZero},
|
||||
reference::znx::{ZnxSub, ZnxSubInplace, ZnxZero},
|
||||
};
|
||||
|
||||
pub fn vec_znx_sub_scalar<R, A, B, ZNXARI>(res: &mut R, res_col: usize, a: &A, a_col: usize, b: &B, b_col: usize, b_limb: usize)
|
||||
@@ -19,12 +19,7 @@ where
|
||||
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
assert!(
|
||||
b_limb < min_size,
|
||||
"b_limb: {} > min_size: {}",
|
||||
b_limb,
|
||||
min_size
|
||||
);
|
||||
assert!(b_limb < min_size, "b_limb: {b_limb} > min_size: {min_size}");
|
||||
}
|
||||
|
||||
for j in 0..min_size {
|
||||
@@ -44,7 +39,7 @@ pub fn vec_znx_sub_scalar_inplace<R, A, ZNXARI>(res: &mut R, res_col: usize, res
|
||||
where
|
||||
R: VecZnxToMut,
|
||||
A: ScalarZnxToRef,
|
||||
ZNXARI: ZnxSubABInplace,
|
||||
ZNXARI: ZnxSubInplace,
|
||||
{
|
||||
let a: ScalarZnx<&[u8]> = a.to_ref();
|
||||
let mut res: VecZnx<&mut [u8]> = res.to_mut();
|
||||
@@ -54,5 +49,5 @@ where
|
||||
assert!(res_limb < res.size());
|
||||
}
|
||||
|
||||
ZNXARI::znx_sub_ab_inplace(res.at_mut(res_col, res_limb), a.at(a_col, 0));
|
||||
ZNXARI::znx_sub_inplace(res.at_mut(res_col, res_limb), a.at(a_col, 0));
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user