mirror of
https://github.com/arnaucube/poulpy.git
synced 2026-02-10 21:26:41 +01:00
Ref. + AVX code & generic tests + benches (#85)
This commit is contained in:
committed by
GitHub
parent
99b9e3e10e
commit
56dbd29c59
199
poulpy-hal/src/reference/znx/normalization.rs
Normal file
199
poulpy-hal/src/reference/znx/normalization.rs
Normal file
@@ -0,0 +1,199 @@
|
||||
use itertools::izip;
|
||||
|
||||
#[inline(always)]
|
||||
pub fn get_digit(basek: usize, x: i64) -> i64 {
|
||||
(x << (u64::BITS - basek as u32)) >> (u64::BITS - basek as u32)
|
||||
}
|
||||
|
||||
#[inline(always)]
|
||||
pub fn get_carry(basek: usize, x: i64, digit: i64) -> i64 {
|
||||
(x.wrapping_sub(digit)) >> basek
|
||||
}
|
||||
|
||||
#[inline(always)]
|
||||
pub fn znx_normalize_first_step_carry_only_ref(basek: usize, lsh: usize, x: &[i64], carry: &mut [i64]) {
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
assert!(x.len() <= carry.len());
|
||||
assert!(lsh < basek);
|
||||
}
|
||||
|
||||
if lsh == 0 {
|
||||
x.iter().zip(carry.iter_mut()).for_each(|(x, c)| {
|
||||
*c = get_carry(basek, *x, get_digit(basek, *x));
|
||||
});
|
||||
} else {
|
||||
let basek_lsh: usize = basek - lsh;
|
||||
x.iter().zip(carry.iter_mut()).for_each(|(x, c)| {
|
||||
*c = get_carry(basek_lsh, *x, get_digit(basek_lsh, *x));
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
#[inline(always)]
|
||||
pub fn znx_normalize_first_step_inplace_ref(basek: usize, lsh: usize, x: &mut [i64], carry: &mut [i64]) {
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
assert!(x.len() <= carry.len());
|
||||
assert!(lsh < basek);
|
||||
}
|
||||
|
||||
if lsh == 0 {
|
||||
x.iter_mut().zip(carry.iter_mut()).for_each(|(x, c)| {
|
||||
let digit: i64 = get_digit(basek, *x);
|
||||
*c = get_carry(basek, *x, digit);
|
||||
*x = digit;
|
||||
});
|
||||
} else {
|
||||
let basek_lsh: usize = basek - lsh;
|
||||
x.iter_mut().zip(carry.iter_mut()).for_each(|(x, c)| {
|
||||
let digit: i64 = get_digit(basek_lsh, *x);
|
||||
*c = get_carry(basek_lsh, *x, digit);
|
||||
*x = digit << lsh;
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
#[inline(always)]
|
||||
pub fn znx_normalize_first_step_ref(basek: usize, lsh: usize, x: &mut [i64], a: &[i64], carry: &mut [i64]) {
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
assert_eq!(x.len(), a.len());
|
||||
assert!(x.len() <= carry.len());
|
||||
assert!(lsh < basek);
|
||||
}
|
||||
|
||||
if lsh == 0 {
|
||||
izip!(x.iter_mut(), a.iter(), carry.iter_mut()).for_each(|(x, a, c)| {
|
||||
let digit: i64 = get_digit(basek, *a);
|
||||
*c = get_carry(basek, *a, digit);
|
||||
*x = digit;
|
||||
});
|
||||
} else {
|
||||
let basek_lsh: usize = basek - lsh;
|
||||
izip!(x.iter_mut(), a.iter(), carry.iter_mut()).for_each(|(x, a, c)| {
|
||||
let digit: i64 = get_digit(basek_lsh, *a);
|
||||
*c = get_carry(basek_lsh, *a, digit);
|
||||
*x = digit << lsh;
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
#[inline(always)]
|
||||
pub fn znx_normalize_middle_step_carry_only_ref(basek: usize, lsh: usize, x: &[i64], carry: &mut [i64]) {
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
assert!(x.len() <= carry.len());
|
||||
assert!(lsh < basek);
|
||||
}
|
||||
if lsh == 0 {
|
||||
x.iter().zip(carry.iter_mut()).for_each(|(x, c)| {
|
||||
let digit: i64 = get_digit(basek, *x);
|
||||
let carry: i64 = get_carry(basek, *x, digit);
|
||||
let digit_plus_c: i64 = digit + *c;
|
||||
*c = carry + get_carry(basek, digit_plus_c, get_digit(basek, digit_plus_c));
|
||||
});
|
||||
} else {
|
||||
let basek_lsh: usize = basek - lsh;
|
||||
x.iter().zip(carry.iter_mut()).for_each(|(x, c)| {
|
||||
let digit: i64 = get_digit(basek_lsh, *x);
|
||||
let carry: i64 = get_carry(basek_lsh, *x, digit);
|
||||
let digit_plus_c: i64 = (digit << lsh) + *c;
|
||||
*c = carry + get_carry(basek, digit_plus_c, get_digit(basek, digit_plus_c));
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
#[inline(always)]
|
||||
pub fn znx_normalize_middle_step_inplace_ref(basek: usize, lsh: usize, x: &mut [i64], carry: &mut [i64]) {
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
assert!(x.len() <= carry.len());
|
||||
assert!(lsh < basek);
|
||||
}
|
||||
if lsh == 0 {
|
||||
x.iter_mut().zip(carry.iter_mut()).for_each(|(x, c)| {
|
||||
let digit: i64 = get_digit(basek, *x);
|
||||
let carry: i64 = get_carry(basek, *x, digit);
|
||||
let digit_plus_c: i64 = digit + *c;
|
||||
*x = get_digit(basek, digit_plus_c);
|
||||
*c = carry + get_carry(basek, digit_plus_c, *x);
|
||||
});
|
||||
} else {
|
||||
let basek_lsh: usize = basek - lsh;
|
||||
x.iter_mut().zip(carry.iter_mut()).for_each(|(x, c)| {
|
||||
let digit: i64 = get_digit(basek_lsh, *x);
|
||||
let carry: i64 = get_carry(basek_lsh, *x, digit);
|
||||
let digit_plus_c: i64 = (digit << lsh) + *c;
|
||||
*x = get_digit(basek, digit_plus_c);
|
||||
*c = carry + get_carry(basek, digit_plus_c, *x);
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
#[inline(always)]
|
||||
pub fn znx_normalize_middle_step_ref(basek: usize, lsh: usize, x: &mut [i64], a: &[i64], carry: &mut [i64]) {
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
assert_eq!(x.len(), a.len());
|
||||
assert!(x.len() <= carry.len());
|
||||
assert!(lsh < basek);
|
||||
}
|
||||
if lsh == 0 {
|
||||
izip!(x.iter_mut(), a.iter(), carry.iter_mut()).for_each(|(x, a, c)| {
|
||||
let digit: i64 = get_digit(basek, *a);
|
||||
let carry: i64 = get_carry(basek, *a, digit);
|
||||
let digit_plus_c: i64 = digit + *c;
|
||||
*x = get_digit(basek, digit_plus_c);
|
||||
*c = carry + get_carry(basek, digit_plus_c, *x);
|
||||
});
|
||||
} else {
|
||||
let basek_lsh: usize = basek - lsh;
|
||||
izip!(x.iter_mut(), a.iter(), carry.iter_mut()).for_each(|(x, a, c)| {
|
||||
let digit: i64 = get_digit(basek_lsh, *a);
|
||||
let carry: i64 = get_carry(basek_lsh, *a, digit);
|
||||
let digit_plus_c: i64 = (digit << lsh) + *c;
|
||||
*x = get_digit(basek, digit_plus_c);
|
||||
*c = carry + get_carry(basek, digit_plus_c, *x);
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
#[inline(always)]
|
||||
pub fn znx_normalize_final_step_inplace_ref(basek: usize, lsh: usize, x: &mut [i64], carry: &mut [i64]) {
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
assert!(x.len() <= carry.len());
|
||||
assert!(lsh < basek);
|
||||
}
|
||||
|
||||
if lsh == 0 {
|
||||
x.iter_mut().zip(carry.iter_mut()).for_each(|(x, c)| {
|
||||
*x = get_digit(basek, get_digit(basek, *x) + *c);
|
||||
});
|
||||
} else {
|
||||
let basek_lsh: usize = basek - lsh;
|
||||
x.iter_mut().zip(carry.iter_mut()).for_each(|(x, c)| {
|
||||
*x = get_digit(basek, (get_digit(basek_lsh, *x) << lsh) + *c);
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
#[inline(always)]
|
||||
pub fn znx_normalize_final_step_ref(basek: usize, lsh: usize, x: &mut [i64], a: &[i64], carry: &mut [i64]) {
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
assert!(x.len() <= carry.len());
|
||||
assert!(lsh < basek);
|
||||
}
|
||||
if lsh == 0 {
|
||||
izip!(x.iter_mut(), a.iter(), carry.iter_mut()).for_each(|(x, a, c)| {
|
||||
*x = get_digit(basek, get_digit(basek, *a) + *c);
|
||||
});
|
||||
} else {
|
||||
let basek_lsh: usize = basek - lsh;
|
||||
izip!(x.iter_mut(), a.iter(), carry.iter_mut()).for_each(|(x, a, c)| {
|
||||
*x = get_digit(basek, (get_digit(basek_lsh, *a) << lsh) + *c);
|
||||
});
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user