Add cross-basek normalization (#90)

* added cross_basek_normalization

* updated method signatures to take layouts

* fixed cross-base normalization

fix #91
fix #93
Jean-Philippe Bossuat
2025-09-30 14:40:10 +02:00
committed by GitHub
parent 4da790ea6a
commit 37e13b965c
216 changed files with 12481 additions and 7745 deletions
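For context on the diff below: the vec_znx_lsh/rsh routines split a shift by k bits into whole-limb moves (steps = k / base2k) and an intra-limb remainder (k_rem = k % base2k), then handle the remainder through the znx_normalize_*_step helpers. A minimal standalone sketch of that decomposition, using a toy 10-bit base and a hypothetical recompose helper so the check fits in an i128 (this is not the crate's API or limb layout):

const BASE2K: usize = 10; // toy base; the code below uses base2k = 50

/// Recompose digits (least-significant first) into an i128 for checking.
fn recompose(digits: &[i64]) -> i128 {
    digits
        .iter()
        .enumerate()
        .map(|(j, &d)| (d as i128) << (j * BASE2K))
        .sum()
}

fn main() {
    let digits: [i64; 3] = [123, -456, 789]; // small signed digits
    let k: usize = 17; // shift amount, deliberately not a multiple of BASE2K

    let steps = k / BASE2K; // whole-limb part of the shift
    let k_rem = k % BASE2K; // intra-limb part of the shift

    // Shift the recomposed value directly ...
    let expected = recompose(&digits) << k;

    // ... and via the digit decomposition: move digits up by `steps` limbs,
    // then shift the whole thing by the remaining `k_rem` bits.
    let mut moved = vec![0i64; digits.len() + steps];
    for (j, &d) in digits.iter().enumerate() {
        moved[j + steps] = d;
    }
    let via_digits = recompose(&moved) << k_rem;

    assert_eq!(expected, via_digits);
    println!("k = {k} -> {steps} whole limbs + {k_rem} intra-limb bits");
}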


@@ -20,7 +20,7 @@ pub fn vec_znx_lsh_tmp_bytes(n: usize) -> usize {
n * size_of::<i64>()
}
pub fn vec_znx_lsh_inplace<R, ZNXARI>(basek: usize, k: usize, res: &mut R, res_col: usize, carry: &mut [i64])
pub fn vec_znx_lsh_inplace<R, ZNXARI>(base2k: usize, k: usize, res: &mut R, res_col: usize, carry: &mut [i64])
where
R: VecZnxToMut,
ZNXARI: ZnxZero
@@ -35,8 +35,8 @@ where
let n: usize = res.n();
let cols: usize = res.cols();
let size: usize = res.size();
let steps: usize = k / basek;
let k_rem: usize = k % basek;
let steps: usize = k / base2k;
let k_rem: usize = k % base2k;
if steps >= size {
for j in 0..size {
@@ -45,7 +45,7 @@ where
return;
}
// Inplace shift of limbs by a k/basek
// Inplace shift of limbs by k/base2k
if steps > 0 {
let start: usize = n * res_col;
let end: usize = start + n;
@@ -65,21 +65,21 @@ where
}
}
// Inplace normalization with left shift of k % basek
if !k.is_multiple_of(basek) {
// Inplace normalization with left shift of k % base2k
if !k.is_multiple_of(base2k) {
for j in (0..size - steps).rev() {
if j == size - steps - 1 {
ZNXARI::znx_normalize_first_step_inplace(basek, k_rem, res.at_mut(res_col, j), carry);
ZNXARI::znx_normalize_first_step_inplace(base2k, k_rem, res.at_mut(res_col, j), carry);
} else if j == 0 {
ZNXARI::znx_normalize_final_step_inplace(basek, k_rem, res.at_mut(res_col, j), carry);
ZNXARI::znx_normalize_final_step_inplace(base2k, k_rem, res.at_mut(res_col, j), carry);
} else {
ZNXARI::znx_normalize_middle_step_inplace(basek, k_rem, res.at_mut(res_col, j), carry);
ZNXARI::znx_normalize_middle_step_inplace(base2k, k_rem, res.at_mut(res_col, j), carry);
}
}
}
}
pub fn vec_znx_lsh<R, A, ZNXARI>(basek: usize, k: usize, res: &mut R, res_col: usize, a: &A, a_col: usize, carry: &mut [i64])
pub fn vec_znx_lsh<R, A, ZNXARI>(base2k: usize, k: usize, res: &mut R, res_col: usize, a: &A, a_col: usize, carry: &mut [i64])
where
R: VecZnxToMut,
A: VecZnxToRef,
@@ -90,8 +90,8 @@ where
let res_size: usize = res.size();
let a_size = a.size();
let steps: usize = k / basek;
let k_rem: usize = k % basek;
let steps: usize = k / base2k;
let k_rem: usize = k % base2k;
if steps >= res_size.min(a_size) {
for j in 0..res_size {
@@ -103,12 +103,12 @@ where
let min_size: usize = a_size.min(res_size) - steps;
// Simply a left shifted normalization of limbs
// by k/basek and intra-limb by basek - k%basek
if !k.is_multiple_of(basek) {
// by k/base2k and intra-limb by base2k - k%base2k
if !k.is_multiple_of(base2k) {
for j in (0..min_size).rev() {
if j == min_size - 1 {
ZNXARI::znx_normalize_first_step(
basek,
base2k,
k_rem,
res.at_mut(res_col, j),
a.at(a_col, j + steps),
@@ -116,7 +116,7 @@ where
);
} else if j == 0 {
ZNXARI::znx_normalize_final_step(
basek,
base2k,
k_rem,
res.at_mut(res_col, j),
a.at(a_col, j + steps),
@@ -124,7 +124,7 @@ where
);
} else {
ZNXARI::znx_normalize_middle_step(
basek,
base2k,
k_rem,
res.at_mut(res_col, j),
a.at(a_col, j + steps),
@@ -133,7 +133,7 @@ where
}
}
} else {
// If k % basek = 0, then this is simply a copy.
// If k % base2k = 0, then this is simply a copy.
for j in (0..min_size).rev() {
ZNXARI::znx_copy(res.at_mut(res_col, j), a.at(a_col, j + steps));
}
@@ -149,7 +149,7 @@ pub fn vec_znx_rsh_tmp_bytes(n: usize) -> usize {
n * size_of::<i64>()
}
pub fn vec_znx_rsh_inplace<R, ZNXARI>(basek: usize, k: usize, res: &mut R, res_col: usize, carry: &mut [i64])
pub fn vec_znx_rsh_inplace<R, ZNXARI>(base2k: usize, k: usize, res: &mut R, res_col: usize, carry: &mut [i64])
where
R: VecZnxToMut,
ZNXARI: ZnxZero
@@ -166,8 +166,8 @@ where
let cols: usize = res.cols();
let size: usize = res.size();
let mut steps: usize = k / basek;
let k_rem: usize = k % basek;
let mut steps: usize = k / base2k;
let k_rem: usize = k % base2k;
if k == 0 {
return;
@@ -184,8 +184,8 @@ where
let end: usize = start + n;
let slice_size: usize = n * cols;
if !k.is_multiple_of(basek) {
// We rsh by an additional basek and then lsh by basek-k
if !k.is_multiple_of(base2k) {
// We rsh by an additional base2k and then lsh by base2k - (k % base2k)
// Allows re-use of the efficient normalization code, avoids
// overflows & produces output that is normalized
steps += 1;
@@ -194,9 +194,9 @@ where
// but the carry still needs to be computed.
(size - steps..size).rev().for_each(|j| {
if j == size - 1 {
ZNXARI::znx_normalize_first_step_carry_only(basek, basek - k_rem, res.at(res_col, j), carry);
ZNXARI::znx_normalize_first_step_carry_only(base2k, base2k - k_rem, res.at(res_col, j), carry);
} else {
ZNXARI::znx_normalize_middle_step_carry_only(basek, basek - k_rem, res.at(res_col, j), carry);
ZNXARI::znx_normalize_middle_step_carry_only(base2k, base2k - k_rem, res.at(res_col, j), carry);
}
});
@@ -206,20 +206,20 @@ where
let (lhs, rhs) = res_raw.split_at_mut(slice_size * j);
let rhs_slice: &mut [i64] = &mut rhs[start..end];
let lhs_slice: &[i64] = &lhs[(j - steps) * slice_size + start..(j - steps) * slice_size + end];
ZNXARI::znx_normalize_middle_step(basek, basek - k_rem, rhs_slice, lhs_slice, carry);
ZNXARI::znx_normalize_middle_step(base2k, base2k - k_rem, rhs_slice, lhs_slice, carry);
});
// Propagates carry on the rest of the limbs of res
for j in (0..steps).rev() {
ZNXARI::znx_zero(res.at_mut(res_col, j));
if j == 0 {
ZNXARI::znx_normalize_final_step_inplace(basek, basek - k_rem, res.at_mut(res_col, j), carry);
ZNXARI::znx_normalize_final_step_inplace(base2k, base2k - k_rem, res.at_mut(res_col, j), carry);
} else {
ZNXARI::znx_normalize_middle_step_inplace(basek, basek - k_rem, res.at_mut(res_col, j), carry);
ZNXARI::znx_normalize_middle_step_inplace(base2k, base2k - k_rem, res.at_mut(res_col, j), carry);
}
}
} else {
// Shift by multiples of basek
// Shift by multiples of base2k
let res_raw: &mut [i64] = res.raw_mut();
(steps..size).rev().for_each(|j| {
let (lhs, rhs) = res_raw.split_at_mut(slice_size * j);
@@ -236,7 +236,7 @@ where
}
}
pub fn vec_znx_rsh<R, A, ZNXARI>(basek: usize, k: usize, res: &mut R, res_col: usize, a: &A, a_col: usize, carry: &mut [i64])
pub fn vec_znx_rsh<R, A, ZNXARI>(base2k: usize, k: usize, res: &mut R, res_col: usize, a: &A, a_col: usize, carry: &mut [i64])
where
R: VecZnxToMut,
A: VecZnxToRef,
@@ -256,8 +256,8 @@ where
let res_size: usize = res.size();
let a_size: usize = a.size();
let mut steps: usize = k / basek;
let k_rem: usize = k % basek;
let mut steps: usize = k / base2k;
let k_rem: usize = k % base2k;
if k == 0 {
vec_znx_copy::<_, _, ZNXARI>(&mut res, res_col, &a, a_col);
@@ -271,8 +271,8 @@ where
return;
}
if !k.is_multiple_of(basek) {
// We rsh by an additional basek and then lsh by basek-k
if !k.is_multiple_of(base2k) {
// We rsh by an additional base2k and then lsh by base2k - (k % base2k)
// Allows re-use of the efficient normalization code, avoids
// overflows & produces output that is normalized
steps += 1;
@@ -281,9 +281,9 @@ where
// but the carry still needs to be computed.
for j in (res_size..a_size + steps).rev() {
if j == a_size + steps - 1 {
ZNXARI::znx_normalize_first_step_carry_only(basek, basek - k_rem, a.at(a_col, j - steps), carry);
ZNXARI::znx_normalize_first_step_carry_only(base2k, base2k - k_rem, a.at(a_col, j - steps), carry);
} else {
ZNXARI::znx_normalize_middle_step_carry_only(basek, basek - k_rem, a.at(a_col, j - steps), carry);
ZNXARI::znx_normalize_middle_step_carry_only(base2k, base2k - k_rem, a.at(a_col, j - steps), carry);
}
}
@@ -300,16 +300,16 @@ where
// Case if no limb of a was previously discarded
if res_size.saturating_sub(steps) >= a_size && j == min_size - 1 {
ZNXARI::znx_normalize_first_step(
basek,
basek - k_rem,
base2k,
base2k - k_rem,
res.at_mut(res_col, j),
a.at(a_col, j - steps),
carry,
);
} else {
ZNXARI::znx_normalize_middle_step(
basek,
basek - k_rem,
base2k,
base2k - k_rem,
res.at_mut(res_col, j),
a.at(a_col, j - steps),
carry,
@@ -321,9 +321,9 @@ where
for j in (0..steps).rev() {
ZNXARI::znx_zero(res.at_mut(res_col, j));
if j == 0 {
ZNXARI::znx_normalize_final_step_inplace(basek, basek - k_rem, res.at_mut(res_col, j), carry);
ZNXARI::znx_normalize_final_step_inplace(base2k, base2k - k_rem, res.at_mut(res_col, j), carry);
} else {
ZNXARI::znx_normalize_middle_step_inplace(basek, basek - k_rem, res.at_mut(res_col, j), carry);
ZNXARI::znx_normalize_middle_step_inplace(base2k, base2k - k_rem, res.at_mut(res_col, j), carry);
}
}
} else {
@@ -351,7 +351,7 @@ where
Module<B>: ModuleNew<B> + VecZnxLshInplace<B>,
ScratchOwned<B>: ScratchOwnedAlloc<B> + ScratchOwnedBorrow<B>,
{
let group_name: String = format!("vec_znx_lsh_inplace::{}", label);
let group_name: String = format!("vec_znx_lsh_inplace::{label}");
let mut group = c.benchmark_group(group_name);
@@ -366,7 +366,7 @@ where
let module: Module<B> = Module::<B>::new(n as u64);
let basek: usize = 50;
let base2k: usize = 50;
let mut source: Source = Source::new([0u8; 32]);
@@ -381,7 +381,7 @@ where
move || {
for i in 0..cols {
module.vec_znx_lsh_inplace(basek, basek - 1, &mut b, i, scratch.borrow());
module.vec_znx_lsh_inplace(base2k, base2k - 1, &mut b, i, scratch.borrow());
}
black_box(());
}
@@ -401,7 +401,7 @@ where
Module<B>: VecZnxLsh<B> + ModuleNew<B>,
ScratchOwned<B>: ScratchOwnedAlloc<B> + ScratchOwnedBorrow<B>,
{
let group_name: String = format!("vec_znx_lsh::{}", label);
let group_name: String = format!("vec_znx_lsh::{label}");
let mut group = c.benchmark_group(group_name);
@@ -416,7 +416,7 @@ where
let module: Module<B> = Module::<B>::new(n as u64);
let basek: usize = 50;
let base2k: usize = 50;
let mut source: Source = Source::new([0u8; 32]);
@@ -431,7 +431,7 @@ where
move || {
for i in 0..cols {
module.vec_znx_lsh(basek, basek - 1, &mut res, i, &a, i, scratch.borrow());
module.vec_znx_lsh(base2k, base2k - 1, &mut res, i, &a, i, scratch.borrow());
}
black_box(());
}
@@ -451,7 +451,7 @@ where
Module<B>: VecZnxRshInplace<B> + ModuleNew<B>,
ScratchOwned<B>: ScratchOwnedAlloc<B> + ScratchOwnedBorrow<B>,
{
let group_name: String = format!("vec_znx_rsh_inplace::{}", label);
let group_name: String = format!("vec_znx_rsh_inplace::{label}");
let mut group = c.benchmark_group(group_name);
@@ -466,7 +466,7 @@ where
let module: Module<B> = Module::<B>::new(n as u64);
let basek: usize = 50;
let base2k: usize = 50;
let mut source: Source = Source::new([0u8; 32]);
@@ -481,7 +481,7 @@ where
move || {
for i in 0..cols {
module.vec_znx_rsh_inplace(basek, basek - 1, &mut b, i, scratch.borrow());
module.vec_znx_rsh_inplace(base2k, base2k - 1, &mut b, i, scratch.borrow());
}
black_box(());
}
@@ -501,7 +501,7 @@ where
Module<B>: VecZnxRsh<B> + ModuleNew<B>,
ScratchOwned<B>: ScratchOwnedAlloc<B> + ScratchOwnedBorrow<B>,
{
let group_name: String = format!("vec_znx_rsh::{}", label);
let group_name: String = format!("vec_znx_rsh::{label}");
let mut group = c.benchmark_group(group_name);
@@ -516,7 +516,7 @@ where
let module: Module<B> = Module::<B>::new(n as u64);
let basek: usize = 50;
let base2k: usize = 50;
let mut source: Source = Source::new([0u8; 32]);
@@ -531,7 +531,7 @@ where
move || {
for i in 0..cols {
module.vec_znx_rsh(basek, basek - 1, &mut res, i, &a, i, scratch.borrow());
module.vec_znx_rsh(base2k, base2k - 1, &mut res, i, &a, i, scratch.borrow());
}
black_box(());
}
@@ -553,7 +553,7 @@ mod tests {
reference::{
vec_znx::{
vec_znx_copy, vec_znx_lsh, vec_znx_lsh_inplace, vec_znx_normalize_inplace, vec_znx_rsh, vec_znx_rsh_inplace,
vec_znx_sub_ab_inplace,
vec_znx_sub_inplace,
},
znx::ZnxRef,
},
@@ -574,20 +574,20 @@ mod tests {
let mut carry: Vec<i64> = vec![0i64; n];
let basek: usize = 50;
let base2k: usize = 50;
for k in 0..256 {
a.fill_uniform(50, &mut source);
for i in 0..cols {
vec_znx_normalize_inplace::<_, ZnxRef>(basek, &mut a, i, &mut carry);
vec_znx_normalize_inplace::<_, ZnxRef>(base2k, &mut a, i, &mut carry);
vec_znx_copy::<_, _, ZnxRef>(&mut res_ref, i, &a, i);
}
for i in 0..cols {
vec_znx_lsh_inplace::<_, ZnxRef>(basek, k, &mut res_ref, i, &mut carry);
vec_znx_lsh::<_, _, ZnxRef>(basek, k, &mut res_test, i, &a, i, &mut carry);
vec_znx_normalize_inplace::<_, ZnxRef>(basek, &mut res_test, i, &mut carry);
vec_znx_lsh_inplace::<_, ZnxRef>(base2k, k, &mut res_ref, i, &mut carry);
vec_znx_lsh::<_, _, ZnxRef>(base2k, k, &mut res_test, i, &a, i, &mut carry);
vec_znx_normalize_inplace::<_, ZnxRef>(base2k, &mut res_test, i, &mut carry);
}
assert_eq!(res_ref, res_test);
@@ -606,7 +606,7 @@ mod tests {
let mut carry: Vec<i64> = vec![0i64; n];
let basek: usize = 50;
let base2k: usize = 50;
let mut source: Source = Source::new([0u8; 32]);
@@ -615,29 +615,29 @@ mod tests {
for a_size in [res_size - 1, res_size, res_size + 1] {
let mut a: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, a_size);
for k in 0..res_size * basek {
for k in 0..res_size * base2k {
a.fill_uniform(50, &mut source);
for i in 0..cols {
vec_znx_normalize_inplace::<_, ZnxRef>(basek, &mut a, i, &mut carry);
vec_znx_normalize_inplace::<_, ZnxRef>(base2k, &mut a, i, &mut carry);
vec_znx_copy::<_, _, ZnxRef>(&mut res_ref, i, &a, i);
}
res_test.fill_uniform(50, &mut source);
for j in 0..cols {
vec_znx_rsh_inplace::<_, ZnxRef>(basek, k, &mut res_ref, j, &mut carry);
vec_znx_rsh::<_, _, ZnxRef>(basek, k, &mut res_test, j, &a, j, &mut carry);
vec_znx_rsh_inplace::<_, ZnxRef>(base2k, k, &mut res_ref, j, &mut carry);
vec_znx_rsh::<_, _, ZnxRef>(base2k, k, &mut res_test, j, &a, j, &mut carry);
}
for j in 0..cols {
vec_znx_lsh_inplace::<_, ZnxRef>(basek, k, &mut res_ref, j, &mut carry);
vec_znx_lsh_inplace::<_, ZnxRef>(basek, k, &mut res_test, j, &mut carry);
vec_znx_lsh_inplace::<_, ZnxRef>(base2k, k, &mut res_ref, j, &mut carry);
vec_znx_lsh_inplace::<_, ZnxRef>(base2k, k, &mut res_test, j, &mut carry);
}
// Case where res has enough limbs to fully store a right-shifted `a` without any loss.
// In this case we can check exact equality.
if a_size + k.div_ceil(basek) <= res_size {
if a_size + k.div_ceil(base2k) <= res_size {
assert_eq!(res_ref, res_test);
for i in 0..cols {
@@ -656,14 +656,14 @@ mod tests {
// res.
} else {
for j in 0..cols {
vec_znx_sub_ab_inplace::<_, _, ZnxRef>(&mut res_ref, j, &a, j);
vec_znx_sub_ab_inplace::<_, _, ZnxRef>(&mut res_test, j, &a, j);
vec_znx_sub_inplace::<_, _, ZnxRef>(&mut res_ref, j, &a, j);
vec_znx_sub_inplace::<_, _, ZnxRef>(&mut res_test, j, &a, j);
vec_znx_normalize_inplace::<_, ZnxRef>(basek, &mut res_ref, j, &mut carry);
vec_znx_normalize_inplace::<_, ZnxRef>(basek, &mut res_test, j, &mut carry);
vec_znx_normalize_inplace::<_, ZnxRef>(base2k, &mut res_ref, j, &mut carry);
vec_znx_normalize_inplace::<_, ZnxRef>(base2k, &mut res_test, j, &mut carry);
assert!(res_ref.std(basek, j).log2() - (k as f64) <= (k * basek) as f64);
assert!(res_test.std(basek, j).log2() - (k as f64) <= (k * basek) as f64);
assert!(res_ref.std(base2k, j).log2() - (k as f64) <= (k * base2k) as f64);
assert!(res_test.std(base2k, j).log2() - (k as f64) <= (k * base2k) as f64);
}
}
}
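A closing note on the right-shift path in this diff: when k is not a multiple of base2k, the code shifts right by one extra whole limb and then shifts left by base2k - k_rem, which lets it reuse the normalization steps, avoids overflow, and leaves the output normalized. A standalone sanity check of the underlying bit-count identity (toy loop, not the crate's API):

fn main() {
    let base2k: usize = 50;
    for k in 1..=3 * base2k {
        if k % base2k == 0 {
            continue; // this path only applies when k has an intra-limb part
        }
        let steps = k / base2k;
        let k_rem = k % base2k;
        // rsh by (steps + 1) whole limbs, then lsh back by (base2k - k_rem):
        let net_shift = (steps + 1) * base2k - (base2k - k_rem);
        assert_eq!(net_shift, k);
    }
    println!("rsh by k == rsh by (steps + 1) limbs, then lsh by base2k - k_rem");
}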