Add cross-basek normalization (#90)

* added cross_basek_normalization

* updated method signatures to take layouts

* fixed cross-base normalization

fix #91
fix #93
This commit is contained in:
Jean-Philippe Bossuat
2025-09-30 14:40:10 +02:00
committed by GitHub
parent 4da790ea6a
commit 37e13b965c
216 changed files with 12481 additions and 7745 deletions

View File

@@ -91,7 +91,7 @@ pub fn bench_vec_znx_add<B: Backend>(c: &mut Criterion, label: &str)
where
Module<B>: VecZnxAdd + ModuleNew<B>,
{
let group_name: String = format!("vec_znx_add::{}", label);
let group_name: String = format!("vec_znx_add::{label}");
let mut group = c.benchmark_group(group_name);
@@ -136,7 +136,7 @@ pub fn bench_vec_znx_add_inplace<B: Backend>(c: &mut Criterion, label: &str)
where
Module<B>: VecZnxAddInplace + ModuleNew<B>,
{
let group_name: String = format!("vec_znx_add_inplace::{}", label);
let group_name: String = format!("vec_znx_add_inplace::{label}");
let mut group = c.benchmark_group(group_name);

View File

@@ -18,12 +18,7 @@ where
#[cfg(debug_assertions)]
{
assert!(
b_limb < min_size,
"b_limb: {} > min_size: {}",
b_limb,
min_size
);
assert!(b_limb < min_size, "b_limb: {b_limb} > min_size: {min_size}");
}
for j in 0..min_size {

View File

@@ -63,7 +63,7 @@ pub fn bench_vec_znx_automorphism<B: Backend>(c: &mut Criterion, label: &str)
where
Module<B>: VecZnxAutomorphism + ModuleNew<B>,
{
let group_name: String = format!("vec_znx_automorphism::{}", label);
let group_name: String = format!("vec_znx_automorphism::{label}");
let mut group = c.benchmark_group(group_name);
@@ -108,7 +108,7 @@ where
Module<B>: VecZnxAutomorphismInplace<B> + VecZnxAutomorphismInplaceTmpBytes + ModuleNew<B>,
ScratchOwned<B>: ScratchOwnedAlloc<B> + ScratchOwnedBorrow<B>,
{
let group_name: String = format!("vec_znx_automorphism_inplace::{}", label);
let group_name: String = format!("vec_znx_automorphism_inplace::{label}");
let mut group = c.benchmark_group(group_name);

View File

@@ -9,8 +9,8 @@ use crate::{
},
layouts::{Backend, FillUniform, Module, ScratchOwned, VecZnx, VecZnxToMut, VecZnxToRef, ZnxInfos, ZnxView, ZnxViewMut},
reference::{
vec_znx::{vec_znx_rotate, vec_znx_sub_ab_inplace},
znx::{ZnxNegate, ZnxRotate, ZnxSubABInplace, ZnxSubBAInplace, ZnxZero},
vec_znx::{vec_znx_rotate, vec_znx_sub_inplace},
znx::{ZnxNegate, ZnxRotate, ZnxSubInplace, ZnxSubNegateInplace, ZnxZero},
},
source::Source,
};
@@ -23,16 +23,16 @@ pub fn vec_znx_mul_xp_minus_one<R, A, ZNXARI>(p: i64, res: &mut R, res_col: usiz
where
R: VecZnxToMut,
A: VecZnxToRef,
ZNXARI: ZnxRotate + ZnxZero + ZnxSubABInplace,
ZNXARI: ZnxRotate + ZnxZero + ZnxSubInplace,
{
vec_znx_rotate::<_, _, ZNXARI>(p, res, res_col, a, a_col);
vec_znx_sub_ab_inplace::<_, _, ZNXARI>(res, res_col, a, a_col);
vec_znx_sub_inplace::<_, _, ZNXARI>(res, res_col, a, a_col);
}
pub fn vec_znx_mul_xp_minus_one_inplace<R, ZNXARI>(p: i64, res: &mut R, res_col: usize, tmp: &mut [i64])
where
R: VecZnxToMut,
ZNXARI: ZnxRotate + ZnxNegate + ZnxSubBAInplace,
ZNXARI: ZnxRotate + ZnxNegate + ZnxSubNegateInplace,
{
let mut res: VecZnx<&mut [u8]> = res.to_mut();
#[cfg(debug_assertions)]
@@ -41,7 +41,7 @@ where
}
for j in 0..res.size() {
ZNXARI::znx_rotate(p, tmp, res.at(res_col, j));
ZNXARI::znx_sub_ba_inplace(res.at_mut(res_col, j), tmp);
ZNXARI::znx_sub_negate_inplace(res.at_mut(res_col, j), tmp);
}
}
@@ -49,7 +49,7 @@ pub fn bench_vec_znx_mul_xp_minus_one<B: Backend>(c: &mut Criterion, label: &str
where
Module<B>: VecZnxMulXpMinusOne + ModuleNew<B>,
{
let group_name: String = format!("vec_znx_mul_xp_minus_one::{}", label);
let group_name: String = format!("vec_znx_mul_xp_minus_one::{label}");
let mut group = c.benchmark_group(group_name);
@@ -94,7 +94,7 @@ where
Module<B>: VecZnxMulXpMinusOneInplace<B> + VecZnxMulXpMinusOneInplaceTmpBytes + ModuleNew<B>,
ScratchOwned<B>: ScratchOwnedAlloc<B> + ScratchOwnedBorrow<B>,
{
let group_name: String = format!("vec_znx_mul_xp_minus_one_inplace::{}", label);
let group_name: String = format!("vec_znx_mul_xp_minus_one_inplace::{label}");
let mut group = c.benchmark_group(group_name);

View File

@@ -49,7 +49,7 @@ pub fn bench_vec_znx_negate<B: Backend>(c: &mut Criterion, label: &str)
where
Module<B>: VecZnxNegate + ModuleNew<B>,
{
let group_name: String = format!("vec_znx_negate::{}", label);
let group_name: String = format!("vec_znx_negate::{label}");
let mut group = c.benchmark_group(group_name);
@@ -93,7 +93,7 @@ pub fn bench_vec_znx_negate_inplace<B: Backend>(c: &mut Criterion, label: &str)
where
Module<B>: VecZnxNegateInplace + ModuleNew<B>,
{
let group_name: String = format!("vec_znx_negate_inplace::{}", label);
let group_name: String = format!("vec_znx_negate_inplace::{label}");
let mut group = c.benchmark_group(group_name);

View File

@@ -6,71 +6,204 @@ use crate::{
api::{ModuleNew, ScratchOwnedAlloc, ScratchOwnedBorrow, VecZnxNormalize, VecZnxNormalizeInplace, VecZnxNormalizeTmpBytes},
layouts::{Backend, FillUniform, Module, ScratchOwned, VecZnx, VecZnxToMut, VecZnxToRef, ZnxInfos, ZnxView, ZnxViewMut},
reference::znx::{
ZnxNormalizeFinalStep, ZnxNormalizeFinalStepInplace, ZnxNormalizeFirstStep, ZnxNormalizeFirstStepCarryOnly,
ZnxNormalizeFirstStepInplace, ZnxNormalizeMiddleStep, ZnxNormalizeMiddleStepCarryOnly, ZnxNormalizeMiddleStepInplace,
ZnxZero,
ZnxAddInplace, ZnxCopy, ZnxExtractDigitAddMul, ZnxMulPowerOfTwoInplace, ZnxNormalizeDigit, ZnxNormalizeFinalStep,
ZnxNormalizeFinalStepInplace, ZnxNormalizeFirstStep, ZnxNormalizeFirstStepCarryOnly, ZnxNormalizeFirstStepInplace,
ZnxNormalizeMiddleStep, ZnxNormalizeMiddleStepCarryOnly, ZnxNormalizeMiddleStepInplace, ZnxZero,
},
source::Source,
};
pub fn vec_znx_normalize_tmp_bytes(n: usize) -> usize {
n * size_of::<i64>()
2 * n * size_of::<i64>()
}
pub fn vec_znx_normalize<R, A, ZNXARI>(basek: usize, res: &mut R, res_col: usize, a: &A, a_col: usize, carry: &mut [i64])
where
pub fn vec_znx_normalize<R, A, ZNXARI>(
res_base2k: usize,
res: &mut R,
res_col: usize,
a_base2k: usize,
a: &A,
a_col: usize,
carry: &mut [i64],
) where
R: VecZnxToMut,
A: VecZnxToRef,
ZNXARI: ZnxZero
+ ZnxCopy
+ ZnxAddInplace
+ ZnxMulPowerOfTwoInplace
+ ZnxNormalizeFirstStepCarryOnly
+ ZnxNormalizeMiddleStepCarryOnly
+ ZnxNormalizeMiddleStep
+ ZnxNormalizeFinalStep
+ ZnxNormalizeFirstStep,
+ ZnxNormalizeFirstStep
+ ZnxExtractDigitAddMul
+ ZnxNormalizeDigit,
{
let mut res: VecZnx<&mut [u8]> = res.to_mut();
let a: VecZnx<&[u8]> = a.to_ref();
#[cfg(debug_assertions)]
{
assert!(carry.len() >= res.n());
assert!(carry.len() >= 2 * res.n());
assert_eq!(res.n(), a.n());
}
let n: usize = res.n();
let res_size: usize = res.size();
let a_size = a.size();
let a_size: usize = a.size();
if a_size > res_size {
for j in (res_size..a_size).rev() {
if j == a_size - 1 {
ZNXARI::znx_normalize_first_step_carry_only(basek, 0, a.at(a_col, j), carry);
} else {
ZNXARI::znx_normalize_middle_step_carry_only(basek, 0, a.at(a_col, j), carry);
if res_base2k == a_base2k {
if a_size > res_size {
for j in (res_size..a_size).rev() {
if j == a_size - 1 {
ZNXARI::znx_normalize_first_step_carry_only(res_base2k, 0, a.at(a_col, j), carry);
} else {
ZNXARI::znx_normalize_middle_step_carry_only(res_base2k, 0, a.at(a_col, j), carry);
}
}
for j in (1..res_size).rev() {
ZNXARI::znx_normalize_middle_step(res_base2k, 0, res.at_mut(res_col, j), a.at(a_col, j), carry);
}
ZNXARI::znx_normalize_final_step(res_base2k, 0, res.at_mut(res_col, 0), a.at(a_col, 0), carry);
} else {
for j in (0..a_size).rev() {
if j == a_size - 1 {
ZNXARI::znx_normalize_first_step(res_base2k, 0, res.at_mut(res_col, j), a.at(a_col, j), carry);
} else if j == 0 {
ZNXARI::znx_normalize_final_step(res_base2k, 0, res.at_mut(res_col, j), a.at(a_col, j), carry);
} else {
ZNXARI::znx_normalize_middle_step(res_base2k, 0, res.at_mut(res_col, j), a.at(a_col, j), carry);
}
}
for j in a_size..res_size {
ZNXARI::znx_zero(res.at_mut(res_col, j));
}
}
for j in (1..res_size).rev() {
ZNXARI::znx_normalize_middle_step(basek, 0, res.at_mut(res_col, j), a.at(a_col, j), carry);
}
ZNXARI::znx_normalize_final_step(basek, 0, res.at_mut(res_col, 0), a.at(a_col, 0), carry);
} else {
for j in (0..a_size).rev() {
let (a_norm, carry) = carry.split_at_mut(n);
// Relevant limbs of res
let res_min_size: usize = (a_size * a_base2k).div_ceil(res_base2k).min(res_size);
// Relevant limbs of a
let a_min_size: usize = (res_size * res_base2k).div_ceil(a_base2k).min(a_size);
// Get carry for limbs of a that have higher precision than res
for j in (a_min_size..a_size).rev() {
if j == a_size - 1 {
ZNXARI::znx_normalize_first_step(basek, 0, res.at_mut(res_col, j), a.at(a_col, j), carry);
} else if j == 0 {
ZNXARI::znx_normalize_final_step(basek, 0, res.at_mut(res_col, j), a.at(a_col, j), carry);
ZNXARI::znx_normalize_first_step_carry_only(res_base2k, 0, a.at(a_col, j), carry);
} else {
ZNXARI::znx_normalize_middle_step(basek, 0, res.at_mut(res_col, j), a.at(a_col, j), carry);
ZNXARI::znx_normalize_middle_step_carry_only(res_base2k, 0, a.at(a_col, j), carry);
}
}
for j in a_size..res_size {
if a_min_size == a_size {
ZNXARI::znx_zero(carry);
}
// Maximum relevant precision of a
let a_prec: usize = a_min_size * a_base2k;
// Maximum relevant precision of res
let res_prec: usize = res_min_size * res_base2k;
// Res limb index
let mut res_idx: usize = res_min_size - 1;
// Trackers: wow much of res is left to be populated
// for the current limb.
let mut res_left: usize = res_base2k;
for j in (0..a_min_size).rev() {
// Trackers: wow much of a_norm is left to
// be flushed on res.
let mut a_left: usize = a_base2k;
// Normalizes the j-th limb of a and store the results into a_norm.
// This step is required to avoid overflow in the next step,
// which assumes that |a| is bounded by 2^{a_base2k -1}.
if j != 0 {
ZNXARI::znx_normalize_middle_step(a_base2k, 0, a_norm, a.at(a_col, j), carry);
} else {
ZNXARI::znx_normalize_final_step(a_base2k, 0, a_norm, a.at(a_col, j), carry);
}
// In the first iteration we need to match the precision of the input/output.
// If a_min_size * a_base2k > res_min_size * res_base2k
// then divround a_norm by the difference of precision and
// acts like if a_norm has already been partially consummed.
// Else acts like if res has been already populated
// by the difference.
if j == a_min_size - 1 {
if a_prec > res_prec {
ZNXARI::znx_mul_power_of_two_inplace(res_prec as i64 - a_prec as i64, a_norm);
a_left -= a_prec - res_prec;
} else if res_prec > a_prec {
res_left -= res_prec - a_prec;
}
}
// Flushes a into res
loop {
// Selects the maximum amount of a that can be flushed
let a_take: usize = a_base2k.min(a_left).min(res_left);
// Output limb
let res_slice: &mut [i64] = res.at_mut(res_col, res_idx);
// Scaling of the value to flush
let lsh: usize = res_base2k - res_left;
// Extract the bits to flush on the output and updates
// a_norm accordingly.
ZNXARI::znx_extract_digit_addmul(a_take, lsh, res_slice, a_norm);
// Updates the trackers
a_left -= a_take;
res_left -= a_take;
// If the current limb of res is full,
// then normalizes this limb and adds
// the carry on a_norm.
if res_left == 0 {
// Updates tracker
res_left += res_base2k;
// Normalizes res and propagates the carry on a.
ZNXARI::znx_normalize_digit(res_base2k, res_slice, a_norm);
// If we reached the last limb of res breaks,
// but we might rerun the above loop if the
// base2k of a is much smaller than the base2k
// of res.
if res_idx == 0 {
ZNXARI::znx_add_inplace(carry, a_norm);
break;
}
// Else updates the limb index of res.
res_idx -= 1
}
// If a_norm is exhausted, breaks the loop.
if a_left == 0 {
ZNXARI::znx_add_inplace(carry, a_norm);
break;
}
}
}
for j in res_min_size..res_size {
ZNXARI::znx_zero(res.at_mut(res_col, j));
}
}
}
pub fn vec_znx_normalize_inplace<R: VecZnxToMut, ZNXARI>(basek: usize, res: &mut R, res_col: usize, carry: &mut [i64])
pub fn vec_znx_normalize_inplace<R: VecZnxToMut, ZNXARI>(base2k: usize, res: &mut R, res_col: usize, carry: &mut [i64])
where
ZNXARI: ZnxNormalizeFirstStepInplace + ZnxNormalizeMiddleStepInplace + ZnxNormalizeFinalStepInplace,
{
@@ -85,11 +218,11 @@ where
for j in (0..res_size).rev() {
if j == res_size - 1 {
ZNXARI::znx_normalize_first_step_inplace(basek, 0, res.at_mut(res_col, j), carry);
ZNXARI::znx_normalize_first_step_inplace(base2k, 0, res.at_mut(res_col, j), carry);
} else if j == 0 {
ZNXARI::znx_normalize_final_step_inplace(basek, 0, res.at_mut(res_col, j), carry);
ZNXARI::znx_normalize_final_step_inplace(base2k, 0, res.at_mut(res_col, j), carry);
} else {
ZNXARI::znx_normalize_middle_step_inplace(basek, 0, res.at_mut(res_col, j), carry);
ZNXARI::znx_normalize_middle_step_inplace(base2k, 0, res.at_mut(res_col, j), carry);
}
}
}
@@ -99,7 +232,7 @@ where
Module<B>: VecZnxNormalize<B> + ModuleNew<B> + VecZnxNormalizeTmpBytes,
ScratchOwned<B>: ScratchOwnedAlloc<B> + ScratchOwnedBorrow<B>,
{
let group_name: String = format!("vec_znx_normalize::{}", label);
let group_name: String = format!("vec_znx_normalize::{label}");
let mut group = c.benchmark_group(group_name);
@@ -114,7 +247,7 @@ where
let module: Module<B> = Module::<B>::new(n as u64);
let basek: usize = 50;
let base2k: usize = 50;
let mut source: Source = Source::new([0u8; 32]);
@@ -129,7 +262,7 @@ where
move || {
for i in 0..cols {
module.vec_znx_normalize(basek, &mut res, i, &a, i, scratch.borrow());
module.vec_znx_normalize(base2k, &mut res, i, base2k, &a, i, scratch.borrow());
}
black_box(());
}
@@ -149,7 +282,7 @@ where
Module<B>: VecZnxNormalizeInplace<B> + ModuleNew<B> + VecZnxNormalizeTmpBytes,
ScratchOwned<B>: ScratchOwnedAlloc<B> + ScratchOwnedBorrow<B>,
{
let group_name: String = format!("vec_znx_normalize_inplace::{}", label);
let group_name: String = format!("vec_znx_normalize_inplace::{label}");
let mut group = c.benchmark_group(group_name);
@@ -164,7 +297,7 @@ where
let module: Module<B> = Module::<B>::new(n as u64);
let basek: usize = 50;
let base2k: usize = 50;
let mut source: Source = Source::new([0u8; 32]);
@@ -177,7 +310,7 @@ where
move || {
for i in 0..cols {
module.vec_znx_normalize_inplace(basek, &mut a, i, scratch.borrow());
module.vec_znx_normalize_inplace(base2k, &mut a, i, scratch.borrow());
}
black_box(());
}
@@ -191,3 +324,83 @@ where
group.finish();
}
#[test]
fn test_vec_znx_normalize_conv() {
let n: usize = 8;
let mut carry: Vec<i64> = vec![0i64; 2 * n];
use crate::reference::znx::ZnxRef;
use rug::ops::SubAssignRound;
use rug::{Float, float::Round};
let mut source: Source = Source::new([1u8; 32]);
let prec: usize = 128;
let mut data: Vec<i128> = vec![0i128; n];
data.iter_mut().for_each(|x| *x = source.next_i128());
for start_base2k in 1..50 {
for end_base2k in 1..50 {
let end_size: usize = prec.div_ceil(end_base2k);
let mut want: VecZnx<Vec<u8>> = VecZnx::alloc(n, 1, end_size);
want.encode_vec_i128(end_base2k, 0, prec, &data);
vec_znx_normalize_inplace::<_, ZnxRef>(end_base2k, &mut want, 0, &mut carry);
// Creates a temporary poly where encoding is in start_base2k
let mut tmp: VecZnx<Vec<u8>> = VecZnx::alloc(n, 1, prec.div_ceil(start_base2k));
tmp.encode_vec_i128(start_base2k, 0, prec, &data);
vec_znx_normalize_inplace::<_, ZnxRef>(start_base2k, &mut tmp, 0, &mut carry);
let mut data_tmp: Vec<Float> = (0..n).map(|_| Float::with_val(prec as u32, 0)).collect();
tmp.decode_vec_float(start_base2k, 0, &mut data_tmp);
let mut have: VecZnx<Vec<u8>> = VecZnx::alloc(n, 1, end_size);
vec_znx_normalize::<_, _, ZnxRef>(end_base2k, &mut have, 0, start_base2k, &tmp, 0, &mut carry);
let out_prec: u32 = (end_size * end_base2k) as u32;
let mut data_want: Vec<Float> = (0..n)
.map(|_| Float::with_val(out_prec as u32, 0))
.collect();
let mut data_res: Vec<Float> = (0..n)
.map(|_| Float::with_val(out_prec as u32, 0))
.collect();
have.decode_vec_float(end_base2k, 0, &mut data_want);
want.decode_vec_float(end_base2k, 0, &mut data_res);
for i in 0..n {
let mut err: Float = data_want[i].clone();
err.sub_assign_round(&data_res[i], Round::Nearest);
err = err.abs();
// println!(
// "want: {} have: {} tmp: {} (want-have): {}",
// data_want[i].to_f64(),
// data_res[i].to_f64(),
// data_tmp[i].to_f64(),
// err.to_f64()
// );
let err_log2: f64 = err
.clone()
.max(&Float::with_val(prec as u32, 1e-60))
.log2()
.to_f64();
assert!(
err_log2 <= -(out_prec as f64) + 1.,
"{} {}",
err_log2,
-(out_prec as f64) + 1.
)
}
}
}
}

View File

@@ -61,7 +61,7 @@ pub fn bench_vec_znx_rotate<B: Backend>(c: &mut Criterion, label: &str)
where
Module<B>: VecZnxRotate + ModuleNew<B>,
{
let group_name: String = format!("vec_znx_rotate::{}", label);
let group_name: String = format!("vec_znx_rotate::{label}");
let mut group = c.benchmark_group(group_name);
@@ -106,7 +106,7 @@ where
Module<B>: VecZnxRotateInplace<B> + VecZnxRotateInplaceTmpBytes + ModuleNew<B>,
ScratchOwned<B>: ScratchOwnedAlloc<B> + ScratchOwnedBorrow<B>,
{
let group_name: String = format!("vec_znx_rotate_inplace::{}", label);
let group_name: String = format!("vec_znx_rotate_inplace::{label}");
let mut group = c.benchmark_group(group_name);

View File

@@ -4,18 +4,18 @@ use crate::{
source::Source,
};
pub fn vec_znx_fill_uniform_ref<R>(basek: usize, res: &mut R, res_col: usize, source: &mut Source)
pub fn vec_znx_fill_uniform_ref<R>(base2k: usize, res: &mut R, res_col: usize, source: &mut Source)
where
R: VecZnxToMut,
{
let mut res: VecZnx<&mut [u8]> = res.to_mut();
for j in 0..res.size() {
znx_fill_uniform_ref(basek, res.at_mut(res_col, j), source)
znx_fill_uniform_ref(base2k, res.at_mut(res_col, j), source)
}
}
pub fn vec_znx_fill_normal_ref<R>(
basek: usize,
base2k: usize,
res: &mut R,
res_col: usize,
k: usize,
@@ -32,8 +32,8 @@ pub fn vec_znx_fill_normal_ref<R>(
(bound.log2().ceil() as i64)
);
let limb: usize = k.div_ceil(basek) - 1;
let scale: f64 = (1 << ((limb + 1) * basek - k)) as f64;
let limb: usize = k.div_ceil(base2k) - 1;
let scale: f64 = (1 << ((limb + 1) * base2k - k)) as f64;
znx_fill_normal_f64_ref(
res.at_mut(res_col, limb),
sigma * scale,
@@ -42,8 +42,15 @@ pub fn vec_znx_fill_normal_ref<R>(
)
}
pub fn vec_znx_add_normal_ref<R>(basek: usize, res: &mut R, res_col: usize, k: usize, sigma: f64, bound: f64, source: &mut Source)
where
pub fn vec_znx_add_normal_ref<R>(
base2k: usize,
res: &mut R,
res_col: usize,
k: usize,
sigma: f64,
bound: f64,
source: &mut Source,
) where
R: VecZnxToMut,
{
let mut res: VecZnx<&mut [u8]> = res.to_mut();
@@ -53,8 +60,8 @@ where
(bound.log2().ceil() as i64)
);
let limb: usize = k.div_ceil(basek) - 1;
let scale: f64 = (1 << ((limb + 1) * basek - k)) as f64;
let limb: usize = k.div_ceil(base2k) - 1;
let scale: f64 = (1 << ((limb + 1) * base2k - k)) as f64;
znx_add_normal_f64_ref(
res.at_mut(res_col, limb),
sigma * scale,

View File

@@ -20,7 +20,7 @@ pub fn vec_znx_lsh_tmp_bytes(n: usize) -> usize {
n * size_of::<i64>()
}
pub fn vec_znx_lsh_inplace<R, ZNXARI>(basek: usize, k: usize, res: &mut R, res_col: usize, carry: &mut [i64])
pub fn vec_znx_lsh_inplace<R, ZNXARI>(base2k: usize, k: usize, res: &mut R, res_col: usize, carry: &mut [i64])
where
R: VecZnxToMut,
ZNXARI: ZnxZero
@@ -35,8 +35,8 @@ where
let n: usize = res.n();
let cols: usize = res.cols();
let size: usize = res.size();
let steps: usize = k / basek;
let k_rem: usize = k % basek;
let steps: usize = k / base2k;
let k_rem: usize = k % base2k;
if steps >= size {
for j in 0..size {
@@ -45,7 +45,7 @@ where
return;
}
// Inplace shift of limbs by a k/basek
// Inplace shift of limbs by a k/base2k
if steps > 0 {
let start: usize = n * res_col;
let end: usize = start + n;
@@ -65,21 +65,21 @@ where
}
}
// Inplace normalization with left shift of k % basek
if !k.is_multiple_of(basek) {
// Inplace normalization with left shift of k % base2k
if !k.is_multiple_of(base2k) {
for j in (0..size - steps).rev() {
if j == size - steps - 1 {
ZNXARI::znx_normalize_first_step_inplace(basek, k_rem, res.at_mut(res_col, j), carry);
ZNXARI::znx_normalize_first_step_inplace(base2k, k_rem, res.at_mut(res_col, j), carry);
} else if j == 0 {
ZNXARI::znx_normalize_final_step_inplace(basek, k_rem, res.at_mut(res_col, j), carry);
ZNXARI::znx_normalize_final_step_inplace(base2k, k_rem, res.at_mut(res_col, j), carry);
} else {
ZNXARI::znx_normalize_middle_step_inplace(basek, k_rem, res.at_mut(res_col, j), carry);
ZNXARI::znx_normalize_middle_step_inplace(base2k, k_rem, res.at_mut(res_col, j), carry);
}
}
}
}
pub fn vec_znx_lsh<R, A, ZNXARI>(basek: usize, k: usize, res: &mut R, res_col: usize, a: &A, a_col: usize, carry: &mut [i64])
pub fn vec_znx_lsh<R, A, ZNXARI>(base2k: usize, k: usize, res: &mut R, res_col: usize, a: &A, a_col: usize, carry: &mut [i64])
where
R: VecZnxToMut,
A: VecZnxToRef,
@@ -90,8 +90,8 @@ where
let res_size: usize = res.size();
let a_size = a.size();
let steps: usize = k / basek;
let k_rem: usize = k % basek;
let steps: usize = k / base2k;
let k_rem: usize = k % base2k;
if steps >= res_size.min(a_size) {
for j in 0..res_size {
@@ -103,12 +103,12 @@ where
let min_size: usize = a_size.min(res_size) - steps;
// Simply a left shifted normalization of limbs
// by k/basek and intra-limb by basek - k%basek
if !k.is_multiple_of(basek) {
// by k/base2k and intra-limb by base2k - k%base2k
if !k.is_multiple_of(base2k) {
for j in (0..min_size).rev() {
if j == min_size - 1 {
ZNXARI::znx_normalize_first_step(
basek,
base2k,
k_rem,
res.at_mut(res_col, j),
a.at(a_col, j + steps),
@@ -116,7 +116,7 @@ where
);
} else if j == 0 {
ZNXARI::znx_normalize_final_step(
basek,
base2k,
k_rem,
res.at_mut(res_col, j),
a.at(a_col, j + steps),
@@ -124,7 +124,7 @@ where
);
} else {
ZNXARI::znx_normalize_middle_step(
basek,
base2k,
k_rem,
res.at_mut(res_col, j),
a.at(a_col, j + steps),
@@ -133,7 +133,7 @@ where
}
}
} else {
// If k % basek = 0, then this is simply a copy.
// If k % base2k = 0, then this is simply a copy.
for j in (0..min_size).rev() {
ZNXARI::znx_copy(res.at_mut(res_col, j), a.at(a_col, j + steps));
}
@@ -149,7 +149,7 @@ pub fn vec_znx_rsh_tmp_bytes(n: usize) -> usize {
n * size_of::<i64>()
}
pub fn vec_znx_rsh_inplace<R, ZNXARI>(basek: usize, k: usize, res: &mut R, res_col: usize, carry: &mut [i64])
pub fn vec_znx_rsh_inplace<R, ZNXARI>(base2k: usize, k: usize, res: &mut R, res_col: usize, carry: &mut [i64])
where
R: VecZnxToMut,
ZNXARI: ZnxZero
@@ -166,8 +166,8 @@ where
let cols: usize = res.cols();
let size: usize = res.size();
let mut steps: usize = k / basek;
let k_rem: usize = k % basek;
let mut steps: usize = k / base2k;
let k_rem: usize = k % base2k;
if k == 0 {
return;
@@ -184,8 +184,8 @@ where
let end: usize = start + n;
let slice_size: usize = n * cols;
if !k.is_multiple_of(basek) {
// We rsh by an additional basek and then lsh by basek-k
if !k.is_multiple_of(base2k) {
// We rsh by an additional base2k and then lsh by base2k-k
// Allows to re-use efficient normalization code, avoids
// avoids overflows & produce output that is normalized
steps += 1;
@@ -194,9 +194,9 @@ where
// but the carry still need to be computed.
(size - steps..size).rev().for_each(|j| {
if j == size - 1 {
ZNXARI::znx_normalize_first_step_carry_only(basek, basek - k_rem, res.at(res_col, j), carry);
ZNXARI::znx_normalize_first_step_carry_only(base2k, base2k - k_rem, res.at(res_col, j), carry);
} else {
ZNXARI::znx_normalize_middle_step_carry_only(basek, basek - k_rem, res.at(res_col, j), carry);
ZNXARI::znx_normalize_middle_step_carry_only(base2k, base2k - k_rem, res.at(res_col, j), carry);
}
});
@@ -206,20 +206,20 @@ where
let (lhs, rhs) = res_raw.split_at_mut(slice_size * j);
let rhs_slice: &mut [i64] = &mut rhs[start..end];
let lhs_slice: &[i64] = &lhs[(j - steps) * slice_size + start..(j - steps) * slice_size + end];
ZNXARI::znx_normalize_middle_step(basek, basek - k_rem, rhs_slice, lhs_slice, carry);
ZNXARI::znx_normalize_middle_step(base2k, base2k - k_rem, rhs_slice, lhs_slice, carry);
});
// Propagates carry on the rest of the limbs of res
for j in (0..steps).rev() {
ZNXARI::znx_zero(res.at_mut(res_col, j));
if j == 0 {
ZNXARI::znx_normalize_final_step_inplace(basek, basek - k_rem, res.at_mut(res_col, j), carry);
ZNXARI::znx_normalize_final_step_inplace(base2k, base2k - k_rem, res.at_mut(res_col, j), carry);
} else {
ZNXARI::znx_normalize_middle_step_inplace(basek, basek - k_rem, res.at_mut(res_col, j), carry);
ZNXARI::znx_normalize_middle_step_inplace(base2k, base2k - k_rem, res.at_mut(res_col, j), carry);
}
}
} else {
// Shift by multiples of basek
// Shift by multiples of base2k
let res_raw: &mut [i64] = res.raw_mut();
(steps..size).rev().for_each(|j| {
let (lhs, rhs) = res_raw.split_at_mut(slice_size * j);
@@ -236,7 +236,7 @@ where
}
}
pub fn vec_znx_rsh<R, A, ZNXARI>(basek: usize, k: usize, res: &mut R, res_col: usize, a: &A, a_col: usize, carry: &mut [i64])
pub fn vec_znx_rsh<R, A, ZNXARI>(base2k: usize, k: usize, res: &mut R, res_col: usize, a: &A, a_col: usize, carry: &mut [i64])
where
R: VecZnxToMut,
A: VecZnxToRef,
@@ -256,8 +256,8 @@ where
let res_size: usize = res.size();
let a_size: usize = a.size();
let mut steps: usize = k / basek;
let k_rem: usize = k % basek;
let mut steps: usize = k / base2k;
let k_rem: usize = k % base2k;
if k == 0 {
vec_znx_copy::<_, _, ZNXARI>(&mut res, res_col, &a, a_col);
@@ -271,8 +271,8 @@ where
return;
}
if !k.is_multiple_of(basek) {
// We rsh by an additional basek and then lsh by basek-k
if !k.is_multiple_of(base2k) {
// We rsh by an additional base2k and then lsh by base2k-k
// Allows to re-use efficient normalization code, avoids
// avoids overflows & produce output that is normalized
steps += 1;
@@ -281,9 +281,9 @@ where
// but the carry still need to be computed.
for j in (res_size..a_size + steps).rev() {
if j == a_size + steps - 1 {
ZNXARI::znx_normalize_first_step_carry_only(basek, basek - k_rem, a.at(a_col, j - steps), carry);
ZNXARI::znx_normalize_first_step_carry_only(base2k, base2k - k_rem, a.at(a_col, j - steps), carry);
} else {
ZNXARI::znx_normalize_middle_step_carry_only(basek, basek - k_rem, a.at(a_col, j - steps), carry);
ZNXARI::znx_normalize_middle_step_carry_only(base2k, base2k - k_rem, a.at(a_col, j - steps), carry);
}
}
@@ -300,16 +300,16 @@ where
// Case if no limb of a was previously discarded
if res_size.saturating_sub(steps) >= a_size && j == min_size - 1 {
ZNXARI::znx_normalize_first_step(
basek,
basek - k_rem,
base2k,
base2k - k_rem,
res.at_mut(res_col, j),
a.at(a_col, j - steps),
carry,
);
} else {
ZNXARI::znx_normalize_middle_step(
basek,
basek - k_rem,
base2k,
base2k - k_rem,
res.at_mut(res_col, j),
a.at(a_col, j - steps),
carry,
@@ -321,9 +321,9 @@ where
for j in (0..steps).rev() {
ZNXARI::znx_zero(res.at_mut(res_col, j));
if j == 0 {
ZNXARI::znx_normalize_final_step_inplace(basek, basek - k_rem, res.at_mut(res_col, j), carry);
ZNXARI::znx_normalize_final_step_inplace(base2k, base2k - k_rem, res.at_mut(res_col, j), carry);
} else {
ZNXARI::znx_normalize_middle_step_inplace(basek, basek - k_rem, res.at_mut(res_col, j), carry);
ZNXARI::znx_normalize_middle_step_inplace(base2k, base2k - k_rem, res.at_mut(res_col, j), carry);
}
}
} else {
@@ -351,7 +351,7 @@ where
Module<B>: ModuleNew<B> + VecZnxLshInplace<B>,
ScratchOwned<B>: ScratchOwnedAlloc<B> + ScratchOwnedBorrow<B>,
{
let group_name: String = format!("vec_znx_lsh_inplace::{}", label);
let group_name: String = format!("vec_znx_lsh_inplace::{label}");
let mut group = c.benchmark_group(group_name);
@@ -366,7 +366,7 @@ where
let module: Module<B> = Module::<B>::new(n as u64);
let basek: usize = 50;
let base2k: usize = 50;
let mut source: Source = Source::new([0u8; 32]);
@@ -381,7 +381,7 @@ where
move || {
for i in 0..cols {
module.vec_znx_lsh_inplace(basek, basek - 1, &mut b, i, scratch.borrow());
module.vec_znx_lsh_inplace(base2k, base2k - 1, &mut b, i, scratch.borrow());
}
black_box(());
}
@@ -401,7 +401,7 @@ where
Module<B>: VecZnxLsh<B> + ModuleNew<B>,
ScratchOwned<B>: ScratchOwnedAlloc<B> + ScratchOwnedBorrow<B>,
{
let group_name: String = format!("vec_znx_lsh::{}", label);
let group_name: String = format!("vec_znx_lsh::{label}");
let mut group = c.benchmark_group(group_name);
@@ -416,7 +416,7 @@ where
let module: Module<B> = Module::<B>::new(n as u64);
let basek: usize = 50;
let base2k: usize = 50;
let mut source: Source = Source::new([0u8; 32]);
@@ -431,7 +431,7 @@ where
move || {
for i in 0..cols {
module.vec_znx_lsh(basek, basek - 1, &mut res, i, &a, i, scratch.borrow());
module.vec_znx_lsh(base2k, base2k - 1, &mut res, i, &a, i, scratch.borrow());
}
black_box(());
}
@@ -451,7 +451,7 @@ where
Module<B>: VecZnxRshInplace<B> + ModuleNew<B>,
ScratchOwned<B>: ScratchOwnedAlloc<B> + ScratchOwnedBorrow<B>,
{
let group_name: String = format!("vec_znx_rsh_inplace::{}", label);
let group_name: String = format!("vec_znx_rsh_inplace::{label}");
let mut group = c.benchmark_group(group_name);
@@ -466,7 +466,7 @@ where
let module: Module<B> = Module::<B>::new(n as u64);
let basek: usize = 50;
let base2k: usize = 50;
let mut source: Source = Source::new([0u8; 32]);
@@ -481,7 +481,7 @@ where
move || {
for i in 0..cols {
module.vec_znx_rsh_inplace(basek, basek - 1, &mut b, i, scratch.borrow());
module.vec_znx_rsh_inplace(base2k, base2k - 1, &mut b, i, scratch.borrow());
}
black_box(());
}
@@ -501,7 +501,7 @@ where
Module<B>: VecZnxRsh<B> + ModuleNew<B>,
ScratchOwned<B>: ScratchOwnedAlloc<B> + ScratchOwnedBorrow<B>,
{
let group_name: String = format!("vec_znx_rsh::{}", label);
let group_name: String = format!("vec_znx_rsh::{label}");
let mut group = c.benchmark_group(group_name);
@@ -516,7 +516,7 @@ where
let module: Module<B> = Module::<B>::new(n as u64);
let basek: usize = 50;
let base2k: usize = 50;
let mut source: Source = Source::new([0u8; 32]);
@@ -531,7 +531,7 @@ where
move || {
for i in 0..cols {
module.vec_znx_rsh(basek, basek - 1, &mut res, i, &a, i, scratch.borrow());
module.vec_znx_rsh(base2k, base2k - 1, &mut res, i, &a, i, scratch.borrow());
}
black_box(());
}
@@ -553,7 +553,7 @@ mod tests {
reference::{
vec_znx::{
vec_znx_copy, vec_znx_lsh, vec_znx_lsh_inplace, vec_znx_normalize_inplace, vec_znx_rsh, vec_znx_rsh_inplace,
vec_znx_sub_ab_inplace,
vec_znx_sub_inplace,
},
znx::ZnxRef,
},
@@ -574,20 +574,20 @@ mod tests {
let mut carry: Vec<i64> = vec![0i64; n];
let basek: usize = 50;
let base2k: usize = 50;
for k in 0..256 {
a.fill_uniform(50, &mut source);
for i in 0..cols {
vec_znx_normalize_inplace::<_, ZnxRef>(basek, &mut a, i, &mut carry);
vec_znx_normalize_inplace::<_, ZnxRef>(base2k, &mut a, i, &mut carry);
vec_znx_copy::<_, _, ZnxRef>(&mut res_ref, i, &a, i);
}
for i in 0..cols {
vec_znx_lsh_inplace::<_, ZnxRef>(basek, k, &mut res_ref, i, &mut carry);
vec_znx_lsh::<_, _, ZnxRef>(basek, k, &mut res_test, i, &a, i, &mut carry);
vec_znx_normalize_inplace::<_, ZnxRef>(basek, &mut res_test, i, &mut carry);
vec_znx_lsh_inplace::<_, ZnxRef>(base2k, k, &mut res_ref, i, &mut carry);
vec_znx_lsh::<_, _, ZnxRef>(base2k, k, &mut res_test, i, &a, i, &mut carry);
vec_znx_normalize_inplace::<_, ZnxRef>(base2k, &mut res_test, i, &mut carry);
}
assert_eq!(res_ref, res_test);
@@ -606,7 +606,7 @@ mod tests {
let mut carry: Vec<i64> = vec![0i64; n];
let basek: usize = 50;
let base2k: usize = 50;
let mut source: Source = Source::new([0u8; 32]);
@@ -615,29 +615,29 @@ mod tests {
for a_size in [res_size - 1, res_size, res_size + 1] {
let mut a: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, a_size);
for k in 0..res_size * basek {
for k in 0..res_size * base2k {
a.fill_uniform(50, &mut source);
for i in 0..cols {
vec_znx_normalize_inplace::<_, ZnxRef>(basek, &mut a, i, &mut carry);
vec_znx_normalize_inplace::<_, ZnxRef>(base2k, &mut a, i, &mut carry);
vec_znx_copy::<_, _, ZnxRef>(&mut res_ref, i, &a, i);
}
res_test.fill_uniform(50, &mut source);
for j in 0..cols {
vec_znx_rsh_inplace::<_, ZnxRef>(basek, k, &mut res_ref, j, &mut carry);
vec_znx_rsh::<_, _, ZnxRef>(basek, k, &mut res_test, j, &a, j, &mut carry);
vec_znx_rsh_inplace::<_, ZnxRef>(base2k, k, &mut res_ref, j, &mut carry);
vec_znx_rsh::<_, _, ZnxRef>(base2k, k, &mut res_test, j, &a, j, &mut carry);
}
for j in 0..cols {
vec_znx_lsh_inplace::<_, ZnxRef>(basek, k, &mut res_ref, j, &mut carry);
vec_znx_lsh_inplace::<_, ZnxRef>(basek, k, &mut res_test, j, &mut carry);
vec_znx_lsh_inplace::<_, ZnxRef>(base2k, k, &mut res_ref, j, &mut carry);
vec_znx_lsh_inplace::<_, ZnxRef>(base2k, k, &mut res_test, j, &mut carry);
}
// Case where res has enough to fully store a right shifted without any loss
// In this case we can check exact equality.
if a_size + k.div_ceil(basek) <= res_size {
if a_size + k.div_ceil(base2k) <= res_size {
assert_eq!(res_ref, res_test);
for i in 0..cols {
@@ -656,14 +656,14 @@ mod tests {
// res.
} else {
for j in 0..cols {
vec_znx_sub_ab_inplace::<_, _, ZnxRef>(&mut res_ref, j, &a, j);
vec_znx_sub_ab_inplace::<_, _, ZnxRef>(&mut res_test, j, &a, j);
vec_znx_sub_inplace::<_, _, ZnxRef>(&mut res_ref, j, &a, j);
vec_znx_sub_inplace::<_, _, ZnxRef>(&mut res_test, j, &a, j);
vec_znx_normalize_inplace::<_, ZnxRef>(basek, &mut res_ref, j, &mut carry);
vec_znx_normalize_inplace::<_, ZnxRef>(basek, &mut res_test, j, &mut carry);
vec_znx_normalize_inplace::<_, ZnxRef>(base2k, &mut res_ref, j, &mut carry);
vec_znx_normalize_inplace::<_, ZnxRef>(base2k, &mut res_test, j, &mut carry);
assert!(res_ref.std(basek, j).log2() - (k as f64) <= (k * basek) as f64);
assert!(res_test.std(basek, j).log2() - (k as f64) <= (k * basek) as f64);
assert!(res_ref.std(base2k, j).log2() - (k as f64) <= (k * base2k) as f64);
assert!(res_test.std(base2k, j).log2() - (k as f64) <= (k * base2k) as f64);
}
}
}

View File

@@ -3,10 +3,10 @@ use std::hint::black_box;
use criterion::{BenchmarkId, Criterion};
use crate::{
api::{ModuleNew, VecZnxSub, VecZnxSubABInplace, VecZnxSubBAInplace},
api::{ModuleNew, VecZnxSub, VecZnxSubInplace, VecZnxSubNegateInplace},
layouts::{Backend, FillUniform, Module, VecZnx, VecZnxToMut, VecZnxToRef, ZnxInfos, ZnxView, ZnxViewMut},
oep::{ModuleNewImpl, VecZnxSubABInplaceImpl, VecZnxSubBAInplaceImpl, VecZnxSubImpl},
reference::znx::{ZnxCopy, ZnxNegate, ZnxNegateInplace, ZnxSub, ZnxSubABInplace, ZnxSubBAInplace, ZnxZero},
oep::{ModuleNewImpl, VecZnxSubImpl, VecZnxSubInplaceImpl, VecZnxSubNegateInplaceImpl},
reference::znx::{ZnxCopy, ZnxNegate, ZnxNegateInplace, ZnxSub, ZnxSubInplace, ZnxSubNegateInplace, ZnxZero},
source::Source,
};
@@ -64,11 +64,11 @@ where
}
}
pub fn vec_znx_sub_ab_inplace<R, A, ZNXARI>(res: &mut R, res_col: usize, a: &A, a_col: usize)
pub fn vec_znx_sub_inplace<R, A, ZNXARI>(res: &mut R, res_col: usize, a: &A, a_col: usize)
where
R: VecZnxToMut,
A: VecZnxToRef,
ZNXARI: ZnxSubABInplace,
ZNXARI: ZnxSubInplace,
{
let a: VecZnx<&[u8]> = a.to_ref();
let mut res: VecZnx<&mut [u8]> = res.to_mut();
@@ -84,15 +84,15 @@ where
let sum_size: usize = a_size.min(res_size);
for j in 0..sum_size {
ZNXARI::znx_sub_ab_inplace(res.at_mut(res_col, j), a.at(a_col, j));
ZNXARI::znx_sub_inplace(res.at_mut(res_col, j), a.at(a_col, j));
}
}
pub fn vec_znx_sub_ba_inplace<R, A, ZNXARI>(res: &mut R, res_col: usize, a: &A, a_col: usize)
pub fn vec_znx_sub_negate_inplace<R, A, ZNXARI>(res: &mut R, res_col: usize, a: &A, a_col: usize)
where
R: VecZnxToMut,
A: VecZnxToRef,
ZNXARI: ZnxSubBAInplace + ZnxNegateInplace,
ZNXARI: ZnxSubNegateInplace + ZnxNegateInplace,
{
let a: VecZnx<&[u8]> = a.to_ref();
let mut res: VecZnx<&mut [u8]> = res.to_mut();
@@ -108,7 +108,7 @@ where
let sum_size: usize = a_size.min(res_size);
for j in 0..sum_size {
ZNXARI::znx_sub_ba_inplace(res.at_mut(res_col, j), a.at(a_col, j));
ZNXARI::znx_sub_negate_inplace(res.at_mut(res_col, j), a.at(a_col, j));
}
for j in sum_size..res_size {
@@ -120,7 +120,7 @@ pub fn bench_vec_znx_sub<B>(c: &mut Criterion, label: &str)
where
B: Backend + ModuleNewImpl<B> + VecZnxSubImpl<B>,
{
let group_name: String = format!("vec_znx_sub::{}", label);
let group_name: String = format!("vec_znx_sub::{label}");
let mut group = c.benchmark_group(group_name);
@@ -161,17 +161,17 @@ where
group.finish();
}
pub fn bench_vec_znx_sub_ab_inplace<B>(c: &mut Criterion, label: &str)
pub fn bench_vec_znx_sub_inplace<B>(c: &mut Criterion, label: &str)
where
B: Backend + ModuleNewImpl<B> + VecZnxSubABInplaceImpl<B>,
B: Backend + ModuleNewImpl<B> + VecZnxSubInplaceImpl<B>,
{
let group_name: String = format!("vec_znx_sub_ab_inplace::{}", label);
let group_name: String = format!("vec_znx_sub_inplace::{label}");
let mut group = c.benchmark_group(group_name);
fn runner<B: Backend>(params: [usize; 3]) -> impl FnMut()
where
Module<B>: VecZnxSubABInplace + ModuleNew<B>,
Module<B>: VecZnxSubInplace + ModuleNew<B>,
{
let n: usize = 1 << params[0];
let cols: usize = params[1];
@@ -190,7 +190,7 @@ where
move || {
for i in 0..cols {
module.vec_znx_sub_ab_inplace(&mut b, i, &a, i);
module.vec_znx_sub_inplace(&mut b, i, &a, i);
}
black_box(());
}
@@ -205,17 +205,17 @@ where
group.finish();
}
pub fn bench_vec_znx_sub_ba_inplace<B>(c: &mut Criterion, label: &str)
pub fn bench_vec_znx_sub_negate_inplace<B>(c: &mut Criterion, label: &str)
where
B: Backend + ModuleNewImpl<B> + VecZnxSubBAInplaceImpl<B>,
B: Backend + ModuleNewImpl<B> + VecZnxSubNegateInplaceImpl<B>,
{
let group_name: String = format!("vec_znx_sub_ba_inplace::{}", label);
let group_name: String = format!("vec_znx_sub_negate_inplace::{label}");
let mut group = c.benchmark_group(group_name);
fn runner<B: Backend>(params: [usize; 3]) -> impl FnMut()
where
Module<B>: VecZnxSubBAInplace + ModuleNew<B>,
Module<B>: VecZnxSubNegateInplace + ModuleNew<B>,
{
let n: usize = 1 << params[0];
let cols: usize = params[1];
@@ -234,7 +234,7 @@ where
move || {
for i in 0..cols {
module.vec_znx_sub_ba_inplace(&mut b, i, &a, i);
module.vec_znx_sub_negate_inplace(&mut b, i, &a, i);
}
black_box(());
}

View File

@@ -1,7 +1,7 @@
use crate::layouts::{ScalarZnxToRef, VecZnxToMut, VecZnxToRef};
use crate::{
layouts::{ScalarZnx, VecZnx, ZnxInfos, ZnxView, ZnxViewMut},
reference::znx::{ZnxSub, ZnxSubABInplace, ZnxZero},
reference::znx::{ZnxSub, ZnxSubInplace, ZnxZero},
};
pub fn vec_znx_sub_scalar<R, A, B, ZNXARI>(res: &mut R, res_col: usize, a: &A, a_col: usize, b: &B, b_col: usize, b_limb: usize)
@@ -19,12 +19,7 @@ where
#[cfg(debug_assertions)]
{
assert!(
b_limb < min_size,
"b_limb: {} > min_size: {}",
b_limb,
min_size
);
assert!(b_limb < min_size, "b_limb: {b_limb} > min_size: {min_size}");
}
for j in 0..min_size {
@@ -44,7 +39,7 @@ pub fn vec_znx_sub_scalar_inplace<R, A, ZNXARI>(res: &mut R, res_col: usize, res
where
R: VecZnxToMut,
A: ScalarZnxToRef,
ZNXARI: ZnxSubABInplace,
ZNXARI: ZnxSubInplace,
{
let a: ScalarZnx<&[u8]> = a.to_ref();
let mut res: VecZnx<&mut [u8]> = res.to_mut();
@@ -54,5 +49,5 @@ where
assert!(res_limb < res.size());
}
ZNXARI::znx_sub_ab_inplace(res.at_mut(res_col, res_limb), a.at(a_col, 0));
ZNXARI::znx_sub_inplace(res.at_mut(res_col, res_limb), a.at(a_col, 0));
}