Mirror of https://github.com/arnaucube/poulpy.git, synced 2026-02-10 13:16:44 +01:00
Add cross-basek normalization (#90)
* added cross_basek_normalization
* updated method signatures to take layouts
* fixed cross-base normalization

fix #91 fix #93
committed by GitHub
parent 4da790ea6a
commit 37e13b965c
@@ -1,5 +1,6 @@
mod add;
mod automorphism;
mod mul;
mod neg;
mod normalization;
mod sub;
@@ -7,6 +8,7 @@ mod switch_ring;

pub(crate) use add::*;
pub(crate) use automorphism::*;
pub(crate) use mul::*;
pub(crate) use neg::*;
pub(crate) use normalization::*;
pub(crate) use sub::*;
poulpy-backend/src/cpu_fft64_avx/znx_avx/mul.rs (new file, 318 lines)
@@ -0,0 +1,318 @@
/// Multiply/divide by a power of two with rounding matching [poulpy_hal::reference::znx::znx_mul_power_of_two_ref].
///
/// # Safety
/// Caller must ensure the CPU supports AVX2 (e.g., via `is_x86_feature_detected!("avx2")`);
/// all inputs must have the same length and must not alias.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
pub unsafe fn znx_mul_power_of_two_avx(k: i64, res: &mut [i64], a: &[i64]) {
    #[cfg(debug_assertions)]
    {
        assert_eq!(res.len(), a.len());
    }

    use core::arch::x86_64::{
        __m128i, __m256i, _mm_cvtsi32_si128, _mm256_add_epi64, _mm256_and_si256, _mm256_cmpgt_epi64, _mm256_loadu_si256,
        _mm256_or_si256, _mm256_set1_epi64x, _mm256_setzero_si256, _mm256_sll_epi64, _mm256_srl_epi64, _mm256_srli_epi64,
        _mm256_storeu_si256, _mm256_sub_epi64,
    };

    let n: usize = res.len();

    if n == 0 {
        return;
    }

    if k == 0 {
        use poulpy_hal::reference::znx::znx_copy_ref;
        znx_copy_ref(res, a);
        return;
    }

    let span: usize = n >> 2; // number of 256-bit chunks

    unsafe {
        let mut rr: *mut __m256i = res.as_mut_ptr() as *mut __m256i;
        let mut aa: *const __m256i = a.as_ptr() as *const __m256i;

        if k > 0 {
            // Left shift by k (variable count).
            #[cfg(debug_assertions)]
            {
                debug_assert!(k <= 63);
            }
            let cnt128: __m128i = _mm_cvtsi32_si128(k as i32);
            for _ in 0..span {
                let x: __m256i = _mm256_loadu_si256(aa);
                let y: __m256i = _mm256_sll_epi64(x, cnt128);
                _mm256_storeu_si256(rr, y);
                rr = rr.add(1);
                aa = aa.add(1);
            }

            // tail
            if !n.is_multiple_of(4) {
                use poulpy_hal::reference::znx::znx_mul_power_of_two_ref;

                znx_mul_power_of_two_ref(k, &mut res[span << 2..], &a[span << 2..]);
            }
            return;
        }

        // k < 0 => arithmetic right shift with rounding:
        // for each x:
        //   sign_bit = (x >> 63) & 1
        //   bias = (1<<(kp-1)) - sign_bit
        //   t = x + bias
        //   y = t >> kp (arithmetic)
        let kp = -k;
        #[cfg(debug_assertions)]
        {
            debug_assert!((1..=63).contains(&kp));
        }

        let cnt_right: __m128i = _mm_cvtsi32_si128(kp as i32);
        let bias_base: __m256i = _mm256_set1_epi64x(1_i64 << (kp - 1));
        let top_mask: __m256i = _mm256_set1_epi64x(-1_i64 << (64 - kp)); // high kp bits
        let zero: __m256i = _mm256_setzero_si256();

        for _ in 0..span {
            let x = _mm256_loadu_si256(aa);

            // bias = (1 << (kp-1)) - sign_bit
            let sign_bit_x: __m256i = _mm256_srli_epi64(x, 63);
            let bias: __m256i = _mm256_sub_epi64(bias_base, sign_bit_x);

            // t = x + bias
            let t: __m256i = _mm256_add_epi64(x, bias);

            // logical shift
            let lsr: __m256i = _mm256_srl_epi64(t, cnt_right);

            // sign extension
            let neg: __m256i = _mm256_cmpgt_epi64(zero, t);
            let fill: __m256i = _mm256_and_si256(neg, top_mask);
            let y: __m256i = _mm256_or_si256(lsr, fill);

            _mm256_storeu_si256(rr, y);
            rr = rr.add(1);
            aa = aa.add(1);
        }
    }

    // tail
    if !n.is_multiple_of(4) {
        use poulpy_hal::reference::znx::znx_mul_power_of_two_ref;

        znx_mul_power_of_two_ref(k, &mut res[span << 2..], &a[span << 2..]);
    }
}
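For reference, the rounding rule that both shift directions above implement can be written as a scalar model (illustrative only, following the in-code comments; the helper name is ours and overflow of `x + bias` is ignored, as in the vector path):

fn mul_power_of_two_scalar(k: i64, x: i64) -> i64 {
    if k >= 0 {
        // multiply by 2^k
        x << k
    } else {
        // divide by 2^kp, rounding to nearest; the sign-dependent bias decides ties
        let kp = -k;
        let sign_bit = (x >> 63) & 1; // 1 if x < 0
        let bias = (1_i64 << (kp - 1)) - sign_bit;
        (x + bias) >> kp // arithmetic shift
    }
}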
/// Multiply/divide inplace by a power of two with rounding matching [poulpy_hal::reference::znx::znx_mul_power_of_two_inplace_ref].
///
/// # Safety
/// Caller must ensure the CPU supports AVX2 (e.g., via `is_x86_feature_detected!("avx2")`);
/// all inputs must have the same length and must not alias.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
pub unsafe fn znx_mul_power_of_two_inplace_avx(k: i64, res: &mut [i64]) {
    use core::arch::x86_64::{
        __m128i, __m256i, _mm_cvtsi32_si128, _mm256_add_epi64, _mm256_and_si256, _mm256_cmpgt_epi64, _mm256_loadu_si256,
        _mm256_or_si256, _mm256_set1_epi64x, _mm256_setzero_si256, _mm256_sll_epi64, _mm256_srl_epi64, _mm256_srli_epi64,
        _mm256_storeu_si256, _mm256_sub_epi64,
    };

    let n: usize = res.len();

    if n == 0 {
        return;
    }

    if k == 0 {
        return;
    }

    let span: usize = n >> 2; // number of 256-bit chunks

    unsafe {
        let mut rr: *mut __m256i = res.as_mut_ptr() as *mut __m256i;

        if k > 0 {
            // Left shift by k (variable count).
            #[cfg(debug_assertions)]
            {
                debug_assert!(k <= 63);
            }
            let cnt128: __m128i = _mm_cvtsi32_si128(k as i32);
            for _ in 0..span {
                let x: __m256i = _mm256_loadu_si256(rr);
                let y: __m256i = _mm256_sll_epi64(x, cnt128);
                _mm256_storeu_si256(rr, y);
                rr = rr.add(1);
            }

            // tail
            if !n.is_multiple_of(4) {
                use poulpy_hal::reference::znx::znx_mul_power_of_two_inplace_ref;
                znx_mul_power_of_two_inplace_ref(k, &mut res[span << 2..]);
            }
            return;
        }

        // k < 0 => arithmetic right shift with rounding:
        // for each x:
        //   sign_bit = (x >> 63) & 1
        //   bias = (1<<(kp-1)) - sign_bit
        //   t = x + bias
        //   y = t >> kp (arithmetic)
        let kp = -k;
        #[cfg(debug_assertions)]
        {
            debug_assert!((1..=63).contains(&kp));
        }

        let cnt_right: __m128i = _mm_cvtsi32_si128(kp as i32);
        let bias_base: __m256i = _mm256_set1_epi64x(1_i64 << (kp - 1));
        let top_mask: __m256i = _mm256_set1_epi64x(-1_i64 << (64 - kp)); // high kp bits
        let zero: __m256i = _mm256_setzero_si256();

        for _ in 0..span {
            let x = _mm256_loadu_si256(rr);

            // bias = (1 << (kp-1)) - sign_bit
            let sign_bit_x: __m256i = _mm256_srli_epi64(x, 63);
            let bias: __m256i = _mm256_sub_epi64(bias_base, sign_bit_x);

            // t = x + bias
            let t: __m256i = _mm256_add_epi64(x, bias);

            // logical shift
            let lsr: __m256i = _mm256_srl_epi64(t, cnt_right);

            // sign extension
            let neg: __m256i = _mm256_cmpgt_epi64(zero, t);
            let fill: __m256i = _mm256_and_si256(neg, top_mask);
            let y: __m256i = _mm256_or_si256(lsr, fill);

            _mm256_storeu_si256(rr, y);
            rr = rr.add(1);
        }
    }

    // tail
    if !n.is_multiple_of(4) {
        use poulpy_hal::reference::znx::znx_mul_power_of_two_inplace_ref;
        znx_mul_power_of_two_inplace_ref(k, &mut res[span << 2..]);
    }
}
/// Multiply/divide by a power of two and add on the result with rounding matching [poulpy_hal::reference::znx::znx_mul_power_of_two_inplace_ref].
///
/// # Safety
/// Caller must ensure the CPU supports AVX2 (e.g., via `is_x86_feature_detected!("avx2")`);
/// all inputs must have the same length and must not alias.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
pub unsafe fn znx_mul_add_power_of_two_avx(k: i64, res: &mut [i64], a: &[i64]) {
    #[cfg(debug_assertions)]
    {
        assert_eq!(res.len(), a.len());
    }

    use core::arch::x86_64::{
        __m128i, __m256i, _mm_cvtsi32_si128, _mm256_add_epi64, _mm256_and_si256, _mm256_cmpgt_epi64, _mm256_loadu_si256,
        _mm256_or_si256, _mm256_set1_epi64x, _mm256_setzero_si256, _mm256_sll_epi64, _mm256_srl_epi64, _mm256_srli_epi64,
        _mm256_storeu_si256, _mm256_sub_epi64,
    };

    let n: usize = res.len();

    if n == 0 {
        return;
    }

    if k == 0 {
        use crate::cpu_fft64_avx::znx_avx::znx_add_inplace_avx;

        znx_add_inplace_avx(res, a);
        return;
    }

    let span: usize = n >> 2; // number of 256-bit chunks

    unsafe {
        let mut rr: *mut __m256i = res.as_mut_ptr() as *mut __m256i;
        let mut aa: *const __m256i = a.as_ptr() as *const __m256i;

        if k > 0 {
            // Left shift by k (variable count).
            #[cfg(debug_assertions)]
            {
                debug_assert!(k <= 63);
            }
            let cnt128: __m128i = _mm_cvtsi32_si128(k as i32);
            for _ in 0..span {
                let x: __m256i = _mm256_loadu_si256(aa);
                let y: __m256i = _mm256_loadu_si256(rr);
                _mm256_storeu_si256(rr, _mm256_add_epi64(y, _mm256_sll_epi64(x, cnt128)));
                rr = rr.add(1);
                aa = aa.add(1);
            }

            // tail
            if !n.is_multiple_of(4) {
                use poulpy_hal::reference::znx::znx_mul_add_power_of_two_ref;

                znx_mul_add_power_of_two_ref(k, &mut res[span << 2..], &a[span << 2..]);
            }
            return;
        }

        // k < 0 => arithmetic right shift with rounding:
        // for each x:
        //   sign_bit = (x >> 63) & 1
        //   bias = (1<<(kp-1)) - sign_bit
        //   t = x + bias
        //   y = t >> kp (arithmetic)
        let kp = -k;
        #[cfg(debug_assertions)]
        {
            debug_assert!((1..=63).contains(&kp));
        }

        let cnt_right: __m128i = _mm_cvtsi32_si128(kp as i32);
        let bias_base: __m256i = _mm256_set1_epi64x(1_i64 << (kp - 1));
        let top_mask: __m256i = _mm256_set1_epi64x(-1_i64 << (64 - kp)); // high kp bits
        let zero: __m256i = _mm256_setzero_si256();

        for _ in 0..span {
            let x: __m256i = _mm256_loadu_si256(aa);
            let y: __m256i = _mm256_loadu_si256(rr);

            // bias = (1 << (kp-1)) - sign_bit
            let sign_bit_x: __m256i = _mm256_srli_epi64(x, 63);
            let bias: __m256i = _mm256_sub_epi64(bias_base, sign_bit_x);

            // t = x + bias
            let t: __m256i = _mm256_add_epi64(x, bias);

            // logical shift
            let lsr: __m256i = _mm256_srl_epi64(t, cnt_right);

            // sign extension
            let neg: __m256i = _mm256_cmpgt_epi64(zero, t);
            let fill: __m256i = _mm256_and_si256(neg, top_mask);
            let out: __m256i = _mm256_or_si256(lsr, fill);

            _mm256_storeu_si256(rr, _mm256_add_epi64(y, out));
            rr = rr.add(1);
            aa = aa.add(1);
        }
    }

    // tail
    if !n.is_multiple_of(4) {
        use poulpy_hal::reference::znx::znx_mul_add_power_of_two_ref;
        znx_mul_add_power_of_two_ref(k, &mut res[span << 2..], &a[span << 2..]);
    }
}
@@ -6,14 +6,14 @@ use std::arch::x86_64::__m256i;
/// Caller must ensure the CPU supports AVX2 (e.g., via `is_x86_feature_detected!("avx2")`);
/// all inputs must have the same length and must not alias.
#[target_feature(enable = "avx2")]
fn normalize_consts_avx(basek: usize) -> (__m256i, __m256i, __m256i, __m256i) {
fn normalize_consts_avx(base2k: usize) -> (__m256i, __m256i, __m256i, __m256i) {
    use std::arch::x86_64::_mm256_set1_epi64x;

    assert!((1..=63).contains(&basek));
    let mask_k: i64 = ((1u64 << basek) - 1) as i64; // 0..k-1 bits set
    let sign_k: i64 = (1u64 << (basek - 1)) as i64; // bit k-1
    let topmask: i64 = (!0u64 << (64 - basek)) as i64; // top k bits set
    let sh_k: __m256i = _mm256_set1_epi64x(basek as i64);
    assert!((1..=63).contains(&base2k));
    let mask_k: i64 = ((1u64 << base2k) - 1) as i64; // 0..k-1 bits set
    let sign_k: i64 = (1u64 << (base2k - 1)) as i64; // bit k-1
    let topmask: i64 = (!0u64 << (64 - base2k)) as i64; // top k bits set
    let sh_k: __m256i = _mm256_set1_epi64x(base2k as i64);
    (
        _mm256_set1_epi64x(mask_k), // mask_k_vec
        _mm256_set1_epi64x(sign_k), // sign_k_vec
@@ -46,14 +46,14 @@ fn get_digit_avx(x: __m256i, mask_k: __m256i, sign_k: __m256i) -> __m256i {
unsafe fn get_carry_avx(
    x: __m256i,
    digit: __m256i,
    basek: __m256i, // _mm256_set1_epi64x(k)
    base2k: __m256i, // _mm256_set1_epi64x(k)
    top_mask: __m256i, // (!0 << (64 - k)) broadcast
) -> __m256i {
    use std::arch::x86_64::{
        __m256i, _mm256_and_si256, _mm256_cmpgt_epi64, _mm256_or_si256, _mm256_setzero_si256, _mm256_srlv_epi64, _mm256_sub_epi64,
    };
    let diff: __m256i = _mm256_sub_epi64(x, digit);
    let lsr: __m256i = _mm256_srlv_epi64(diff, basek); // logical >>
    let lsr: __m256i = _mm256_srlv_epi64(diff, base2k); // logical >>
    let neg: __m256i = _mm256_cmpgt_epi64(_mm256_setzero_si256(), diff); // 0xFFFF.. where v<0
    let fill: __m256i = _mm256_and_si256(neg, top_mask); // top k bits if negative
    _mm256_or_si256(lsr, fill)
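In scalar terms, the digit/carry split that these constants feed can be sketched as follows (an illustrative model of what `get_digit_avx` and `get_carry_avx` compute per 64-bit lane; the helper names are ours, not the crate's API):

fn digit_of(base2k: usize, x: i64) -> i64 {
    // sign-extended low base2k bits: (x << (64 - base2k)) >> (64 - base2k)
    (x << (64 - base2k)) >> (64 - base2k)
}

fn carry_of(base2k: usize, x: i64, digit: i64) -> i64 {
    // remaining high part: (x - digit) >> base2k, as an arithmetic shift
    (x - digit) >> base2k
}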
@@ -61,13 +61,121 @@ unsafe fn get_carry_avx(

/// # Safety
/// Caller must ensure the CPU supports AVX2 (e.g., via `is_x86_feature_detected!("avx2")`);
/// all inputs must have the same length and must not alias.
/// `res` and `src` must have the same length and must not alias.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
pub fn znx_normalize_first_step_carry_only_avx(basek: usize, lsh: usize, x: &[i64], carry: &mut [i64]) {
pub fn znx_extract_digit_addmul_avx(base2k: usize, lsh: usize, res: &mut [i64], src: &mut [i64]) {
    #[cfg(debug_assertions)]
    {
        assert_eq!(x.len(), carry.len());
        assert!(lsh < basek);
        assert_eq!(res.len(), src.len());
        assert!(lsh < base2k);
    }

    use std::arch::x86_64::{
        __m256i, _mm256_add_epi64, _mm256_loadu_si256, _mm256_set1_epi64x, _mm256_sllv_epi64, _mm256_storeu_si256,
    };

    let n: usize = res.len();
    let span: usize = n >> 2;

    unsafe {
        let mut rr: *mut __m256i = res.as_mut_ptr() as *mut __m256i;
        let mut ss: *mut __m256i = src.as_mut_ptr() as *mut __m256i;

        // constants for digit/carry extraction
        let (mask, sign, basek_vec, top_mask) = normalize_consts_avx(base2k);
        let lsh_v: __m256i = _mm256_set1_epi64x(lsh as i64);

        for _ in 0..span {
            // load source & extract digit/carry
            let sv: __m256i = _mm256_loadu_si256(ss);
            let digit_256: __m256i = get_digit_avx(sv, mask, sign);
            let carry_256: __m256i = get_carry_avx(sv, digit_256, basek_vec, top_mask);

            // res += (digit << lsh)
            let rv: __m256i = _mm256_loadu_si256(rr);
            let madd: __m256i = _mm256_sllv_epi64(digit_256, lsh_v);
            let sum: __m256i = _mm256_add_epi64(rv, madd);

            _mm256_storeu_si256(rr, sum);
            _mm256_storeu_si256(ss, carry_256);

            rr = rr.add(1);
            ss = ss.add(1);
        }
    }

    // tail (scalar)
    if !n.is_multiple_of(4) {
        use poulpy_hal::reference::znx::znx_extract_digit_addmul_ref;

        let off: usize = span << 2;
        znx_extract_digit_addmul_ref(base2k, lsh, &mut res[off..], &mut src[off..]);
    }
}
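A scalar sketch of the kernel above (illustrative): for every coefficient, the low base2k-bit digit of `src` is shifted left by `lsh` and accumulated into `res`, while `src` is replaced by the remaining carry.

fn extract_digit_addmul_scalar(base2k: usize, lsh: usize, res: &mut [i64], src: &mut [i64]) {
    for (r, s) in res.iter_mut().zip(src.iter_mut()) {
        let digit = (*s << (64 - base2k)) >> (64 - base2k); // sign-extended low bits
        let carry = (*s - digit) >> base2k;                 // what is left for the next limb
        *r += digit << lsh;                                 // res += digit * 2^lsh
        *s = carry;
    }
}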
/// # Safety
/// Caller must ensure the CPU supports AVX2 (e.g., via `is_x86_feature_detected!("avx2")`);
/// `res` and `src` must have the same length and must not alias.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
pub fn znx_normalize_digit_avx(base2k: usize, res: &mut [i64], src: &mut [i64]) {
    #[cfg(debug_assertions)]
    {
        assert_eq!(res.len(), src.len());
    }

    use std::arch::x86_64::{__m256i, _mm256_add_epi64, _mm256_loadu_si256, _mm256_storeu_si256};

    let n: usize = res.len();
    let span: usize = n >> 2;

    unsafe {
        // Pointers to 256-bit lanes
        let mut rr: *mut __m256i = res.as_mut_ptr() as *mut __m256i;
        let mut ss: *mut __m256i = src.as_mut_ptr() as *mut __m256i;

        // Constants for digit/carry extraction
        let (mask, sign, basek_vec, top_mask) = normalize_consts_avx(base2k);

        for _ in 0..span {
            // Load res lane
            let rv: __m256i = _mm256_loadu_si256(rr);

            // Extract digit and carry from res
            let digit_256: __m256i = get_digit_avx(rv, mask, sign);
            let carry_256: __m256i = get_carry_avx(rv, digit_256, basek_vec, top_mask);

            // src += carry
            let sv: __m256i = _mm256_loadu_si256(ss);
            let sum: __m256i = _mm256_add_epi64(sv, carry_256);

            _mm256_storeu_si256(ss, sum);
            _mm256_storeu_si256(rr, digit_256);

            rr = rr.add(1);
            ss = ss.add(1);
        }
    }

    // scalar tail
    if !n.is_multiple_of(4) {
        use poulpy_hal::reference::znx::znx_normalize_digit_ref;

        let off = span << 2;
        znx_normalize_digit_ref(base2k, &mut res[off..], &mut src[off..]);
    }
}
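Scalar model of the limb normalization above (illustrative): `res` keeps only its base-2^base2k digit and the overflow is pushed into `src`, the next-higher limb.

fn normalize_digit_scalar(base2k: usize, res: &mut [i64], src: &mut [i64]) {
    for (r, s) in res.iter_mut().zip(src.iter_mut()) {
        let digit = (*r << (64 - base2k)) >> (64 - base2k);
        let carry = (*r - digit) >> base2k;
        *s += carry; // propagate the overflow to the next limb
        *r = digit;  // keep the normalized digit
    }
}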
/// # Safety
/// Caller must ensure the CPU supports AVX2 (e.g., via `is_x86_feature_detected!("avx2")`);
/// all inputs must have the same length and must not alias.
#[target_feature(enable = "avx2")]
pub fn znx_normalize_first_step_carry_only_avx(base2k: usize, lsh: usize, x: &[i64], carry: &mut [i64]) {
    #[cfg(debug_assertions)]
    {
        assert!(x.len() <= carry.len());
        assert!(lsh < base2k);
    }

    use std::arch::x86_64::{_mm256_loadu_si256, _mm256_storeu_si256};
@@ -81,19 +189,19 @@ pub fn znx_normalize_first_step_carry_only_avx(basek: usize, lsh: usize, x: &[i6
        let mut cc: *mut __m256i = carry.as_ptr() as *mut __m256i;

        let (mask, sign, basek_vec, top_mask) = if lsh == 0 {
            normalize_consts_avx(basek)
            normalize_consts_avx(base2k)
        } else {
            normalize_consts_avx(basek - lsh)
            normalize_consts_avx(base2k - lsh)
        };

        for _ in 0..span {
            let xx_256: __m256i = _mm256_loadu_si256(xx);
            let xv: __m256i = _mm256_loadu_si256(xx);

            // (x << (64 - basek)) >> (64 - basek)
            let digit_256: __m256i = get_digit_avx(xx_256, mask, sign);
            // (x << (64 - base2k)) >> (64 - base2k)
            let digit_256: __m256i = get_digit_avx(xv, mask, sign);

            // (x - digit) >> basek
            let carry_256: __m256i = get_carry_avx(xx_256, digit_256, basek_vec, top_mask);
            // (x - digit) >> base2k
            let carry_256: __m256i = get_carry_avx(xv, digit_256, basek_vec, top_mask);

            _mm256_storeu_si256(cc, carry_256);

@@ -106,7 +214,7 @@ pub fn znx_normalize_first_step_carry_only_avx(basek: usize, lsh: usize, x: &[i6
    if !x.len().is_multiple_of(4) {
        use poulpy_hal::reference::znx::znx_normalize_first_step_carry_only_ref;

        znx_normalize_first_step_carry_only_ref(basek, lsh, &x[span << 2..], &mut carry[span << 2..]);
        znx_normalize_first_step_carry_only_ref(base2k, lsh, &x[span << 2..], &mut carry[span << 2..]);
    }
}

@@ -114,11 +222,11 @@ pub fn znx_normalize_first_step_carry_only_avx(basek: usize, lsh: usize, x: &[i6
/// Caller must ensure the CPU supports AVX2 (e.g., via `is_x86_feature_detected!("avx2")`);
/// all inputs must have the same length and must not alias.
#[target_feature(enable = "avx2")]
pub fn znx_normalize_first_step_inplace_avx(basek: usize, lsh: usize, x: &mut [i64], carry: &mut [i64]) {
pub fn znx_normalize_first_step_inplace_avx(base2k: usize, lsh: usize, x: &mut [i64], carry: &mut [i64]) {
    #[cfg(debug_assertions)]
    {
        assert_eq!(x.len(), carry.len());
        assert!(lsh < basek);
        assert!(x.len() <= carry.len());
        assert!(lsh < base2k);
    }

    use std::arch::x86_64::{_mm256_loadu_si256, _mm256_set1_epi64x, _mm256_sllv_epi64, _mm256_storeu_si256};
@@ -132,16 +240,16 @@ pub fn znx_normalize_first_step_inplace_avx(basek: usize, lsh: usize, x: &mut [i
        let mut cc: *mut __m256i = carry.as_ptr() as *mut __m256i;

        if lsh == 0 {
            let (mask, sign, basek_vec, top_mask) = normalize_consts_avx(basek);
            let (mask, sign, basek_vec, top_mask) = normalize_consts_avx(base2k);

            for _ in 0..span {
                let xx_256: __m256i = _mm256_loadu_si256(xx);
                let xv: __m256i = _mm256_loadu_si256(xx);

                // (x << (64 - basek)) >> (64 - basek)
                let digit_256: __m256i = get_digit_avx(xx_256, mask, sign);
                // (x << (64 - base2k)) >> (64 - base2k)
                let digit_256: __m256i = get_digit_avx(xv, mask, sign);

                // (x - digit) >> basek
                let carry_256: __m256i = get_carry_avx(xx_256, digit_256, basek_vec, top_mask);
                // (x - digit) >> base2k
                let carry_256: __m256i = get_carry_avx(xv, digit_256, basek_vec, top_mask);

                _mm256_storeu_si256(xx, digit_256);
                _mm256_storeu_si256(cc, carry_256);
@@ -150,18 +258,18 @@ pub fn znx_normalize_first_step_inplace_avx(basek: usize, lsh: usize, x: &mut [i
                cc = cc.add(1);
            }
        } else {
            let (mask, sign, basek_vec, top_mask) = normalize_consts_avx(basek - lsh);
            let (mask, sign, basek_vec, top_mask) = normalize_consts_avx(base2k - lsh);

            let lsh_v: __m256i = _mm256_set1_epi64x(lsh as i64);

            for _ in 0..span {
                let xx_256: __m256i = _mm256_loadu_si256(xx);
                let xv: __m256i = _mm256_loadu_si256(xx);

                // (x << (64 - basek)) >> (64 - basek)
                let digit_256: __m256i = get_digit_avx(xx_256, mask, sign);
                // (x << (64 - base2k)) >> (64 - base2k)
                let digit_256: __m256i = get_digit_avx(xv, mask, sign);

                // (x - digit) >> basek
                let carry_256: __m256i = get_carry_avx(xx_256, digit_256, basek_vec, top_mask);
                // (x - digit) >> base2k
                let carry_256: __m256i = get_carry_avx(xv, digit_256, basek_vec, top_mask);

                _mm256_storeu_si256(xx, _mm256_sllv_epi64(digit_256, lsh_v));
                _mm256_storeu_si256(cc, carry_256);
@@ -176,7 +284,7 @@ pub fn znx_normalize_first_step_inplace_avx(basek: usize, lsh: usize, x: &mut [i
    if !x.len().is_multiple_of(4) {
        use poulpy_hal::reference::znx::znx_normalize_first_step_inplace_ref;

        znx_normalize_first_step_inplace_ref(basek, lsh, &mut x[span << 2..], &mut carry[span << 2..]);
        znx_normalize_first_step_inplace_ref(base2k, lsh, &mut x[span << 2..], &mut carry[span << 2..]);
    }
}
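The `lsh` handling above is presumably what enables normalizing across different base2k layouts: when `lsh != 0`, the digit is extracted in base 2^(base2k - lsh) and then re-aligned by `lsh` bits. A scalar sketch of the first in-place step under that reading (illustrative; it covers both the `lsh == 0` and `lsh != 0` branches):

fn normalize_first_step_inplace_scalar(base2k: usize, lsh: usize, x: &mut [i64], carry: &mut [i64]) {
    let k = base2k - lsh; // effective digit width for this step
    for (xi, ci) in x.iter_mut().zip(carry.iter_mut()) {
        let digit = (*xi << (64 - k)) >> (64 - k); // sign-extended low k bits
        *ci = (*xi - digit) >> k;                  // carry for the next limb
        *xi = digit << lsh;                        // digit re-aligned to the base2k layout
    }
}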
@@ -184,12 +292,12 @@ pub fn znx_normalize_first_step_inplace_avx(basek: usize, lsh: usize, x: &mut [i
/// Caller must ensure the CPU supports AVX2 (e.g., via `is_x86_feature_detected!("avx2")`);
/// all inputs must have the same length and must not alias.
#[target_feature(enable = "avx2")]
pub fn znx_normalize_first_step_avx(basek: usize, lsh: usize, x: &mut [i64], a: &[i64], carry: &mut [i64]) {
pub fn znx_normalize_first_step_avx(base2k: usize, lsh: usize, x: &mut [i64], a: &[i64], carry: &mut [i64]) {
    #[cfg(debug_assertions)]
    {
        assert_eq!(x.len(), carry.len());
        assert_eq!(a.len(), carry.len());
        assert!(lsh < basek);
        assert_eq!(x.len(), a.len());
        assert!(x.len() <= carry.len());
        assert!(lsh < base2k);
    }

    use std::arch::x86_64::{_mm256_loadu_si256, _mm256_sllv_epi64, _mm256_storeu_si256};
@@ -204,16 +312,16 @@ pub fn znx_normalize_first_step_avx(basek: usize, lsh: usize, x: &mut [i64], a:
        let mut cc: *mut __m256i = carry.as_ptr() as *mut __m256i;

        if lsh == 0 {
            let (mask, sign, basek_vec, top_mask) = normalize_consts_avx(basek);
            let (mask, sign, basek_vec, top_mask) = normalize_consts_avx(base2k);

            for _ in 0..span {
                let aa_256: __m256i = _mm256_loadu_si256(aa);
                let av: __m256i = _mm256_loadu_si256(aa);

                // (x << (64 - basek)) >> (64 - basek)
                let digit_256: __m256i = get_digit_avx(aa_256, mask, sign);
                // (x << (64 - base2k)) >> (64 - base2k)
                let digit_256: __m256i = get_digit_avx(av, mask, sign);

                // (x - digit) >> basek
                let carry_256: __m256i = get_carry_avx(aa_256, digit_256, basek_vec, top_mask);
                // (x - digit) >> base2k
                let carry_256: __m256i = get_carry_avx(av, digit_256, basek_vec, top_mask);

                _mm256_storeu_si256(xx, digit_256);
                _mm256_storeu_si256(cc, carry_256);
@@ -225,18 +333,18 @@ pub fn znx_normalize_first_step_avx(basek: usize, lsh: usize, x: &mut [i64], a:
        } else {
            use std::arch::x86_64::_mm256_set1_epi64x;

            let (mask, sign, basek_vec, top_mask) = normalize_consts_avx(basek - lsh);
            let (mask, sign, basek_vec, top_mask) = normalize_consts_avx(base2k - lsh);

            let lsh_v: __m256i = _mm256_set1_epi64x(lsh as i64);

            for _ in 0..span {
                let aa_256: __m256i = _mm256_loadu_si256(aa);
                let av: __m256i = _mm256_loadu_si256(aa);

                // (x << (64 - basek)) >> (64 - basek)
                let digit_256: __m256i = get_digit_avx(aa_256, mask, sign);
                // (x << (64 - base2k)) >> (64 - base2k)
                let digit_256: __m256i = get_digit_avx(av, mask, sign);

                // (x - digit) >> basek
                let carry_256: __m256i = get_carry_avx(aa_256, digit_256, basek_vec, top_mask);
                // (x - digit) >> base2k
                let carry_256: __m256i = get_carry_avx(av, digit_256, basek_vec, top_mask);

                _mm256_storeu_si256(xx, _mm256_sllv_epi64(digit_256, lsh_v));
                _mm256_storeu_si256(cc, carry_256);
@@ -253,7 +361,7 @@ pub fn znx_normalize_first_step_avx(basek: usize, lsh: usize, x: &mut [i64], a:
        use poulpy_hal::reference::znx::znx_normalize_first_step_ref;

        znx_normalize_first_step_ref(
            basek,
            base2k,
            lsh,
            &mut x[span << 2..],
            &a[span << 2..],
@@ -266,11 +374,11 @@ pub fn znx_normalize_first_step_avx(basek: usize, lsh: usize, x: &mut [i64], a:
/// Caller must ensure the CPU supports AVX2 (e.g., via `is_x86_feature_detected!("avx2")`);
/// all inputs must have the same length and must not alias.
#[target_feature(enable = "avx2")]
pub fn znx_normalize_middle_step_inplace_avx(basek: usize, lsh: usize, x: &mut [i64], carry: &mut [i64]) {
pub fn znx_normalize_middle_step_inplace_avx(base2k: usize, lsh: usize, x: &mut [i64], carry: &mut [i64]) {
    #[cfg(debug_assertions)]
    {
        assert_eq!(x.len(), carry.len());
        assert!(lsh < basek);
        assert!(x.len() <= carry.len());
        assert!(lsh < base2k);
    }

    use std::arch::x86_64::{_mm256_add_epi64, _mm256_loadu_si256, _mm256_sllv_epi64, _mm256_storeu_si256};
@@ -279,7 +387,7 @@ pub fn znx_normalize_middle_step_inplace_avx(basek: usize, lsh: usize, x: &mut [

    let span: usize = n >> 2;

    let (mask, sign, basek_vec, top_mask) = normalize_consts_avx(basek);
    let (mask, sign, basek_vec, top_mask) = normalize_consts_avx(base2k);

    unsafe {
        let mut xx: *mut __m256i = x.as_mut_ptr() as *mut __m256i;
@@ -287,13 +395,13 @@ pub fn znx_normalize_middle_step_inplace_avx(basek: usize, lsh: usize, x: &mut [

        if lsh == 0 {
            for _ in 0..span {
                let xx_256: __m256i = _mm256_loadu_si256(xx);
                let cc_256: __m256i = _mm256_loadu_si256(cc);
                let xv: __m256i = _mm256_loadu_si256(xx);
                let cv: __m256i = _mm256_loadu_si256(cc);

                let d0: __m256i = get_digit_avx(xx_256, mask, sign);
                let c0: __m256i = get_carry_avx(xx_256, d0, basek_vec, top_mask);
                let d0: __m256i = get_digit_avx(xv, mask, sign);
                let c0: __m256i = get_carry_avx(xv, d0, basek_vec, top_mask);

                let s: __m256i = _mm256_add_epi64(d0, cc_256);
                let s: __m256i = _mm256_add_epi64(d0, cv);
                let x1: __m256i = get_digit_avx(s, mask, sign);
                let c1: __m256i = get_carry_avx(s, x1, basek_vec, top_mask);
                let cout: __m256i = _mm256_add_epi64(c0, c1);
@@ -307,20 +415,20 @@ pub fn znx_normalize_middle_step_inplace_avx(basek: usize, lsh: usize, x: &mut [
        } else {
            use std::arch::x86_64::_mm256_set1_epi64x;

            let (mask_lsh, sign_lsh, basek_vec_lsh, top_mask_lsh) = normalize_consts_avx(basek - lsh);
            let (mask_lsh, sign_lsh, basek_vec_lsh, top_mask_lsh) = normalize_consts_avx(base2k - lsh);

            let lsh_v: __m256i = _mm256_set1_epi64x(lsh as i64);

            for _ in 0..span {
                let xx_256: __m256i = _mm256_loadu_si256(xx);
                let cc_256: __m256i = _mm256_loadu_si256(cc);
                let xv: __m256i = _mm256_loadu_si256(xx);
                let cv: __m256i = _mm256_loadu_si256(cc);

                let d0: __m256i = get_digit_avx(xx_256, mask_lsh, sign_lsh);
                let c0: __m256i = get_carry_avx(xx_256, d0, basek_vec_lsh, top_mask_lsh);
                let d0: __m256i = get_digit_avx(xv, mask_lsh, sign_lsh);
                let c0: __m256i = get_carry_avx(xv, d0, basek_vec_lsh, top_mask_lsh);

                let d0_lsh: __m256i = _mm256_sllv_epi64(d0, lsh_v);

                let s: __m256i = _mm256_add_epi64(d0_lsh, cc_256);
                let s: __m256i = _mm256_add_epi64(d0_lsh, cv);
                let x1: __m256i = get_digit_avx(s, mask, sign);
                let c1: __m256i = get_carry_avx(s, x1, basek_vec, top_mask);
                let cout: __m256i = _mm256_add_epi64(c0, c1);
@@ -337,7 +445,7 @@ pub fn znx_normalize_middle_step_inplace_avx(basek: usize, lsh: usize, x: &mut [
    if !x.len().is_multiple_of(4) {
        use poulpy_hal::reference::znx::znx_normalize_middle_step_inplace_ref;

        znx_normalize_middle_step_inplace_ref(basek, lsh, &mut x[span << 2..], &mut carry[span << 2..]);
        znx_normalize_middle_step_inplace_ref(base2k, lsh, &mut x[span << 2..], &mut carry[span << 2..]);
    }
}
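The middle step folds the incoming carry into the current limb's digit. Per coefficient it can be sketched as (illustrative scalar model; with `lsh == 0` the shift simply disappears, and the resulting pair is presumably written back to `x` and `carry`):

fn normalize_middle_step_scalar(base2k: usize, lsh: usize, xi: i64, carry_in: i64) -> (i64, i64) {
    let k = base2k - lsh;
    let d0 = (xi << (64 - k)) >> (64 - k);          // digit of x in base 2^(base2k - lsh)
    let c0 = (xi - d0) >> k;                        // carry produced by x itself
    let s = (d0 << lsh) + carry_in;                 // re-aligned digit plus incoming carry
    let x1 = (s << (64 - base2k)) >> (64 - base2k); // renormalized digit in base 2^base2k
    let c1 = (s - x1) >> base2k;                    // extra carry from the addition
    (x1, c0 + c1)                                   // (new limb value, carry passed on)
}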
@@ -345,11 +453,11 @@ pub fn znx_normalize_middle_step_inplace_avx(basek: usize, lsh: usize, x: &mut [
/// Caller must ensure the CPU supports AVX2 (e.g., via `is_x86_feature_detected!("avx2")`);
/// all inputs must have the same length and must not alias.
#[target_feature(enable = "avx2")]
pub fn znx_normalize_middle_step_carry_only_avx(basek: usize, lsh: usize, x: &[i64], carry: &mut [i64]) {
pub fn znx_normalize_middle_step_carry_only_avx(base2k: usize, lsh: usize, x: &[i64], carry: &mut [i64]) {
    #[cfg(debug_assertions)]
    {
        assert_eq!(x.len(), carry.len());
        assert!(lsh < basek);
        assert!(x.len() <= carry.len());
        assert!(lsh < base2k);
    }

    use std::arch::x86_64::{_mm256_add_epi64, _mm256_loadu_si256, _mm256_sllv_epi64, _mm256_storeu_si256};
@@ -358,7 +466,7 @@ pub fn znx_normalize_middle_step_carry_only_avx(basek: usize, lsh: usize, x: &[i

    let span: usize = n >> 2;

    let (mask, sign, basek_vec, top_mask) = normalize_consts_avx(basek);
    let (mask, sign, basek_vec, top_mask) = normalize_consts_avx(base2k);

    unsafe {
        let mut xx: *const __m256i = x.as_ptr() as *const __m256i;
@@ -366,13 +474,13 @@ pub fn znx_normalize_middle_step_carry_only_avx(basek: usize, lsh: usize, x: &[i

        if lsh == 0 {
            for _ in 0..span {
                let xx_256: __m256i = _mm256_loadu_si256(xx);
                let cc_256: __m256i = _mm256_loadu_si256(cc);
                let xv: __m256i = _mm256_loadu_si256(xx);
                let cv: __m256i = _mm256_loadu_si256(cc);

                let d0: __m256i = get_digit_avx(xx_256, mask, sign);
                let c0: __m256i = get_carry_avx(xx_256, d0, basek_vec, top_mask);
                let d0: __m256i = get_digit_avx(xv, mask, sign);
                let c0: __m256i = get_carry_avx(xv, d0, basek_vec, top_mask);

                let s: __m256i = _mm256_add_epi64(d0, cc_256);
                let s: __m256i = _mm256_add_epi64(d0, cv);
                let x1: __m256i = get_digit_avx(s, mask, sign);
                let c1: __m256i = get_carry_avx(s, x1, basek_vec, top_mask);
                let cout: __m256i = _mm256_add_epi64(c0, c1);
@@ -385,20 +493,20 @@ pub fn znx_normalize_middle_step_carry_only_avx(basek: usize, lsh: usize, x: &[i
        } else {
            use std::arch::x86_64::_mm256_set1_epi64x;

            let (mask_lsh, sign_lsh, basek_vec_lsh, top_mask_lsh) = normalize_consts_avx(basek - lsh);
            let (mask_lsh, sign_lsh, basek_vec_lsh, top_mask_lsh) = normalize_consts_avx(base2k - lsh);

            let lsh_v: __m256i = _mm256_set1_epi64x(lsh as i64);

            for _ in 0..span {
                let xx_256: __m256i = _mm256_loadu_si256(xx);
                let cc_256: __m256i = _mm256_loadu_si256(cc);
                let xv: __m256i = _mm256_loadu_si256(xx);
                let cv: __m256i = _mm256_loadu_si256(cc);

                let d0: __m256i = get_digit_avx(xx_256, mask_lsh, sign_lsh);
                let c0: __m256i = get_carry_avx(xx_256, d0, basek_vec_lsh, top_mask_lsh);
                let d0: __m256i = get_digit_avx(xv, mask_lsh, sign_lsh);
                let c0: __m256i = get_carry_avx(xv, d0, basek_vec_lsh, top_mask_lsh);

                let d0_lsh: __m256i = _mm256_sllv_epi64(d0, lsh_v);

                let s: __m256i = _mm256_add_epi64(d0_lsh, cc_256);
                let s: __m256i = _mm256_add_epi64(d0_lsh, cv);
                let x1: __m256i = get_digit_avx(s, mask, sign);
                let c1: __m256i = get_carry_avx(s, x1, basek_vec, top_mask);
                let cout: __m256i = _mm256_add_epi64(c0, c1);
@@ -414,7 +522,7 @@ pub fn znx_normalize_middle_step_carry_only_avx(basek: usize, lsh: usize, x: &[i
    if !x.len().is_multiple_of(4) {
        use poulpy_hal::reference::znx::znx_normalize_middle_step_carry_only_ref;

        znx_normalize_middle_step_carry_only_ref(basek, lsh, &x[span << 2..], &mut carry[span << 2..]);
        znx_normalize_middle_step_carry_only_ref(base2k, lsh, &x[span << 2..], &mut carry[span << 2..]);
    }
}
@@ -422,12 +530,12 @@ pub fn znx_normalize_middle_step_carry_only_avx(basek: usize, lsh: usize, x: &[i
/// Caller must ensure the CPU supports AVX2 (e.g., via `is_x86_feature_detected!("avx2")`);
/// all inputs must have the same length and must not alias.
#[target_feature(enable = "avx2")]
pub fn znx_normalize_middle_step_avx(basek: usize, lsh: usize, x: &mut [i64], a: &[i64], carry: &mut [i64]) {
pub fn znx_normalize_middle_step_avx(base2k: usize, lsh: usize, x: &mut [i64], a: &[i64], carry: &mut [i64]) {
    #[cfg(debug_assertions)]
    {
        assert_eq!(x.len(), carry.len());
        assert_eq!(a.len(), carry.len());
        assert!(lsh < basek);
        assert_eq!(x.len(), a.len());
        assert!(x.len() <= carry.len());
        assert!(lsh < base2k);
    }

    use std::arch::x86_64::{_mm256_add_epi64, _mm256_loadu_si256, _mm256_sllv_epi64, _mm256_storeu_si256};
@@ -436,7 +544,7 @@ pub fn znx_normalize_middle_step_avx(basek: usize, lsh: usize, x: &mut [i64], a:

    let span: usize = n >> 2;

    let (mask, sign, basek_vec, top_mask) = normalize_consts_avx(basek);
    let (mask, sign, basek_vec, top_mask) = normalize_consts_avx(base2k);

    unsafe {
        let mut xx: *mut __m256i = x.as_mut_ptr() as *mut __m256i;
@@ -445,13 +553,13 @@ pub fn znx_normalize_middle_step_avx(basek: usize, lsh: usize, x: &mut [i64], a:

        if lsh == 0 {
            for _ in 0..span {
                let aa_256: __m256i = _mm256_loadu_si256(aa);
                let cc_256: __m256i = _mm256_loadu_si256(cc);
                let av: __m256i = _mm256_loadu_si256(aa);
                let cv: __m256i = _mm256_loadu_si256(cc);

                let d0: __m256i = get_digit_avx(aa_256, mask, sign);
                let c0: __m256i = get_carry_avx(aa_256, d0, basek_vec, top_mask);
                let d0: __m256i = get_digit_avx(av, mask, sign);
                let c0: __m256i = get_carry_avx(av, d0, basek_vec, top_mask);

                let s: __m256i = _mm256_add_epi64(d0, cc_256);
                let s: __m256i = _mm256_add_epi64(d0, cv);
                let x1: __m256i = get_digit_avx(s, mask, sign);
                let c1: __m256i = get_carry_avx(s, x1, basek_vec, top_mask);
                let cout: __m256i = _mm256_add_epi64(c0, c1);
@@ -466,20 +574,20 @@ pub fn znx_normalize_middle_step_avx(basek: usize, lsh: usize, x: &mut [i64], a:
        } else {
            use std::arch::x86_64::_mm256_set1_epi64x;

            let (mask_lsh, sign_lsh, basek_vec_lsh, top_mask_lsh) = normalize_consts_avx(basek - lsh);
            let (mask_lsh, sign_lsh, basek_vec_lsh, top_mask_lsh) = normalize_consts_avx(base2k - lsh);

            let lsh_v: __m256i = _mm256_set1_epi64x(lsh as i64);

            for _ in 0..span {
                let aa_256: __m256i = _mm256_loadu_si256(aa);
                let cc_256: __m256i = _mm256_loadu_si256(cc);
                let av: __m256i = _mm256_loadu_si256(aa);
                let cv: __m256i = _mm256_loadu_si256(cc);

                let d0: __m256i = get_digit_avx(aa_256, mask_lsh, sign_lsh);
                let c0: __m256i = get_carry_avx(aa_256, d0, basek_vec_lsh, top_mask_lsh);
                let d0: __m256i = get_digit_avx(av, mask_lsh, sign_lsh);
                let c0: __m256i = get_carry_avx(av, d0, basek_vec_lsh, top_mask_lsh);

                let d0_lsh: __m256i = _mm256_sllv_epi64(d0, lsh_v);

                let s: __m256i = _mm256_add_epi64(d0_lsh, cc_256);
                let s: __m256i = _mm256_add_epi64(d0_lsh, cv);
                let x1: __m256i = get_digit_avx(s, mask, sign);
                let c1: __m256i = get_carry_avx(s, x1, basek_vec, top_mask);
                let cout: __m256i = _mm256_add_epi64(c0, c1);
@@ -498,7 +606,7 @@ pub fn znx_normalize_middle_step_avx(basek: usize, lsh: usize, x: &mut [i64], a:
        use poulpy_hal::reference::znx::znx_normalize_middle_step_ref;

        znx_normalize_middle_step_ref(
            basek,
            base2k,
            lsh,
            &mut x[span << 2..],
            &a[span << 2..],
@@ -511,11 +619,11 @@ pub fn znx_normalize_middle_step_avx(basek: usize, lsh: usize, x: &mut [i64], a:
/// Caller must ensure the CPU supports AVX2 (e.g., via `is_x86_feature_detected!("avx2")`);
/// all inputs must have the same length and must not alias.
#[target_feature(enable = "avx2")]
pub fn znx_normalize_final_step_inplace_avx(basek: usize, lsh: usize, x: &mut [i64], carry: &mut [i64]) {
pub fn znx_normalize_final_step_inplace_avx(base2k: usize, lsh: usize, x: &mut [i64], carry: &mut [i64]) {
    #[cfg(debug_assertions)]
    {
        assert_eq!(x.len(), carry.len());
        assert!(lsh < basek);
        assert!(x.len() <= carry.len());
        assert!(lsh < base2k);
    }

    use std::arch::x86_64::{_mm256_add_epi64, _mm256_loadu_si256, _mm256_sllv_epi64, _mm256_storeu_si256};
@@ -524,7 +632,7 @@ pub fn znx_normalize_final_step_inplace_avx(basek: usize, lsh: usize, x: &mut [i

    let span: usize = n >> 2;

    let (mask, sign, _, _) = normalize_consts_avx(basek);
    let (mask, sign, _, _) = normalize_consts_avx(base2k);

    unsafe {
        let mut xx: *mut __m256i = x.as_mut_ptr() as *mut __m256i;
@@ -547,7 +655,7 @@ pub fn znx_normalize_final_step_inplace_avx(basek: usize, lsh: usize, x: &mut [i
        } else {
            use std::arch::x86_64::_mm256_set1_epi64x;

            let (mask_lsh, sign_lsh, _, _) = normalize_consts_avx(basek - lsh);
            let (mask_lsh, sign_lsh, _, _) = normalize_consts_avx(base2k - lsh);

            let lsh_v: __m256i = _mm256_set1_epi64x(lsh as i64);

@@ -573,7 +681,7 @@ pub fn znx_normalize_final_step_inplace_avx(basek: usize, lsh: usize, x: &mut [i
    if !x.len().is_multiple_of(4) {
        use poulpy_hal::reference::znx::znx_normalize_final_step_inplace_ref;

        znx_normalize_final_step_inplace_ref(basek, lsh, &mut x[span << 2..], &mut carry[span << 2..]);
        znx_normalize_final_step_inplace_ref(base2k, lsh, &mut x[span << 2..], &mut carry[span << 2..]);
    }
}
@@ -581,12 +689,12 @@ pub fn znx_normalize_final_step_inplace_avx(basek: usize, lsh: usize, x: &mut [i
/// Caller must ensure the CPU supports AVX2 (e.g., via `is_x86_feature_detected!("avx2")`);
/// all inputs must have the same length and must not alias.
#[target_feature(enable = "avx2")]
pub fn znx_normalize_final_step_avx(basek: usize, lsh: usize, x: &mut [i64], a: &[i64], carry: &mut [i64]) {
pub fn znx_normalize_final_step_avx(base2k: usize, lsh: usize, x: &mut [i64], a: &[i64], carry: &mut [i64]) {
    #[cfg(debug_assertions)]
    {
        assert_eq!(x.len(), carry.len());
        assert_eq!(a.len(), carry.len());
        assert!(lsh < basek);
        assert_eq!(x.len(), a.len());
        assert!(x.len() <= carry.len());
        assert!(lsh < base2k);
    }

    use std::arch::x86_64::{_mm256_add_epi64, _mm256_loadu_si256, _mm256_sllv_epi64, _mm256_storeu_si256};
@@ -595,7 +703,7 @@ pub fn znx_normalize_final_step_avx(basek: usize, lsh: usize, x: &mut [i64], a:

    let span: usize = n >> 2;

    let (mask, sign, _, _) = normalize_consts_avx(basek);
    let (mask, sign, _, _) = normalize_consts_avx(base2k);

    unsafe {
        let mut xx: *mut __m256i = x.as_mut_ptr() as *mut __m256i;
@@ -620,7 +728,7 @@ pub fn znx_normalize_final_step_avx(basek: usize, lsh: usize, x: &mut [i64], a:
        } else {
            use std::arch::x86_64::_mm256_set1_epi64x;

            let (mask_lsh, sign_lsh, _, _) = normalize_consts_avx(basek - lsh);
            let (mask_lsh, sign_lsh, _, _) = normalize_consts_avx(base2k - lsh);

            let lsh_v: __m256i = _mm256_set1_epi64x(lsh as i64);

@@ -647,7 +755,7 @@ pub fn znx_normalize_final_step_avx(basek: usize, lsh: usize, x: &mut [i64], a:
        use poulpy_hal::reference::znx::znx_normalize_final_step_ref;

        znx_normalize_final_step_ref(
            basek,
            base2k,
            lsh,
            &mut x[span << 2..],
            &a[span << 2..],
@@ -658,9 +766,9 @@ pub fn znx_normalize_final_step_avx(basek: usize, lsh: usize, x: &mut [i64], a:

mod tests {
    use poulpy_hal::reference::znx::{
        get_carry, get_digit, znx_normalize_final_step_inplace_ref, znx_normalize_final_step_ref,
        znx_normalize_first_step_inplace_ref, znx_normalize_first_step_ref, znx_normalize_middle_step_inplace_ref,
        znx_normalize_middle_step_ref,
        get_carry_i64, get_digit_i64, znx_extract_digit_addmul_ref, znx_normalize_digit_ref,
        znx_normalize_final_step_inplace_ref, znx_normalize_final_step_ref, znx_normalize_first_step_inplace_ref,
        znx_normalize_first_step_ref, znx_normalize_middle_step_inplace_ref, znx_normalize_middle_step_ref,
    };

    use super::*;
@@ -670,7 +778,7 @@ mod tests {
    #[allow(dead_code)]
    #[target_feature(enable = "avx2")]
    fn test_get_digit_avx_internal() {
        let basek: usize = 12;
        let base2k: usize = 12;
        let x: [i64; 4] = [
            7638646372408325293,
            -61440197422348985,
@@ -678,15 +786,15 @@ mod tests {
            -4835376105455195188,
        ];
        let y0: Vec<i64> = vec![
            get_digit(basek, x[0]),
            get_digit(basek, x[1]),
            get_digit(basek, x[2]),
            get_digit(basek, x[3]),
            get_digit_i64(base2k, x[0]),
            get_digit_i64(base2k, x[1]),
            get_digit_i64(base2k, x[2]),
            get_digit_i64(base2k, x[3]),
        ];
        let mut y1: Vec<i64> = vec![0i64; 4];
        unsafe {
            let x_256: __m256i = _mm256_loadu_si256(x.as_ptr() as *const __m256i);
            let (mask, sign, _, _) = normalize_consts_avx(basek);
            let (mask, sign, _, _) = normalize_consts_avx(base2k);
            let digit: __m256i = get_digit_avx(x_256, mask, sign);
            _mm256_storeu_si256(y1.as_mut_ptr() as *mut __m256i, digit);
        }
@@ -707,7 +815,7 @@ mod tests {
    #[allow(dead_code)]
    #[target_feature(enable = "avx2")]
    fn test_get_carry_avx_internal() {
        let basek: usize = 12;
        let base2k: usize = 12;
        let x: [i64; 4] = [
            7638646372408325293,
            -61440197422348985,
@@ -716,16 +824,16 @@ mod tests {
        ];
        let carry: [i64; 4] = [1174467039, -144794816, -1466676977, 513122840];
        let y0: Vec<i64> = vec![
            get_carry(basek, x[0], carry[0]),
            get_carry(basek, x[1], carry[1]),
            get_carry(basek, x[2], carry[2]),
            get_carry(basek, x[3], carry[3]),
            get_carry_i64(base2k, x[0], carry[0]),
            get_carry_i64(base2k, x[1], carry[1]),
            get_carry_i64(base2k, x[2], carry[2]),
            get_carry_i64(base2k, x[3], carry[3]),
        ];
        let mut y1: Vec<i64> = vec![0i64; 4];
        unsafe {
            let x_256: __m256i = _mm256_loadu_si256(x.as_ptr() as *const __m256i);
            let d_256: __m256i = _mm256_loadu_si256(carry.as_ptr() as *const __m256i);
            let (_, _, basek_vec, top_mask) = normalize_consts_avx(basek);
            let (_, _, basek_vec, top_mask) = normalize_consts_avx(base2k);
            let digit: __m256i = get_carry_avx(x_256, d_256, basek_vec, top_mask);
            _mm256_storeu_si256(y1.as_mut_ptr() as *mut __m256i, digit);
        }
@@ -762,16 +870,16 @@ mod tests {
        ];
        let mut c1: [i64; 4] = c0;

        let basek = 12;
        let base2k = 12;

        znx_normalize_first_step_inplace_ref(basek, 0, &mut y0, &mut c0);
        znx_normalize_first_step_inplace_avx(basek, 0, &mut y1, &mut c1);
        znx_normalize_first_step_inplace_ref(base2k, 0, &mut y0, &mut c0);
        znx_normalize_first_step_inplace_avx(base2k, 0, &mut y1, &mut c1);

        assert_eq!(y0, y1);
        assert_eq!(c0, c1);

        znx_normalize_first_step_inplace_ref(basek, basek - 1, &mut y0, &mut c0);
        znx_normalize_first_step_inplace_avx(basek, basek - 1, &mut y1, &mut c1);
        znx_normalize_first_step_inplace_ref(base2k, base2k - 1, &mut y0, &mut c0);
        znx_normalize_first_step_inplace_avx(base2k, base2k - 1, &mut y1, &mut c1);

        assert_eq!(y0, y1);
        assert_eq!(c0, c1);
@@ -807,16 +915,16 @@ mod tests {
        ];
        let mut c1: [i64; 4] = c0;

        let basek = 12;
        let base2k = 12;

        znx_normalize_middle_step_inplace_ref(basek, 0, &mut y0, &mut c0);
        znx_normalize_middle_step_inplace_avx(basek, 0, &mut y1, &mut c1);
        znx_normalize_middle_step_inplace_ref(base2k, 0, &mut y0, &mut c0);
        znx_normalize_middle_step_inplace_avx(base2k, 0, &mut y1, &mut c1);

        assert_eq!(y0, y1);
        assert_eq!(c0, c1);

        znx_normalize_middle_step_inplace_ref(basek, basek - 1, &mut y0, &mut c0);
        znx_normalize_middle_step_inplace_avx(basek, basek - 1, &mut y1, &mut c1);
        znx_normalize_middle_step_inplace_ref(base2k, base2k - 1, &mut y0, &mut c0);
        znx_normalize_middle_step_inplace_avx(base2k, base2k - 1, &mut y1, &mut c1);

        assert_eq!(y0, y1);
        assert_eq!(c0, c1);
@@ -852,16 +960,16 @@ mod tests {
        ];
        let mut c1: [i64; 4] = c0;

        let basek = 12;
        let base2k = 12;

        znx_normalize_final_step_inplace_ref(basek, 0, &mut y0, &mut c0);
        znx_normalize_final_step_inplace_avx(basek, 0, &mut y1, &mut c1);
        znx_normalize_final_step_inplace_ref(base2k, 0, &mut y0, &mut c0);
        znx_normalize_final_step_inplace_avx(base2k, 0, &mut y1, &mut c1);

        assert_eq!(y0, y1);
        assert_eq!(c0, c1);

        znx_normalize_final_step_inplace_ref(basek, basek - 1, &mut y0, &mut c0);
        znx_normalize_final_step_inplace_avx(basek, basek - 1, &mut y1, &mut c1);
        znx_normalize_final_step_inplace_ref(base2k, base2k - 1, &mut y0, &mut c0);
        znx_normalize_final_step_inplace_avx(base2k, base2k - 1, &mut y1, &mut c1);

        assert_eq!(y0, y1);
        assert_eq!(c0, c1);
@@ -898,16 +1006,16 @@ mod tests {
        ];
        let mut c1: [i64; 4] = c0;

        let basek = 12;
        let base2k = 12;

        znx_normalize_first_step_ref(basek, 0, &mut y0, &a, &mut c0);
        znx_normalize_first_step_avx(basek, 0, &mut y1, &a, &mut c1);
        znx_normalize_first_step_ref(base2k, 0, &mut y0, &a, &mut c0);
        znx_normalize_first_step_avx(base2k, 0, &mut y1, &a, &mut c1);

        assert_eq!(y0, y1);
        assert_eq!(c0, c1);

        znx_normalize_first_step_ref(basek, basek - 1, &mut y0, &a, &mut c0);
        znx_normalize_first_step_avx(basek, basek - 1, &mut y1, &a, &mut c1);
        znx_normalize_first_step_ref(base2k, base2k - 1, &mut y0, &a, &mut c0);
        znx_normalize_first_step_avx(base2k, base2k - 1, &mut y1, &a, &mut c1);

        assert_eq!(y0, y1);
        assert_eq!(c0, c1);
@@ -944,16 +1052,16 @@ mod tests {
        ];
        let mut c1: [i64; 4] = c0;

        let basek = 12;
        let base2k = 12;

        znx_normalize_middle_step_ref(basek, 0, &mut y0, &a, &mut c0);
        znx_normalize_middle_step_avx(basek, 0, &mut y1, &a, &mut c1);
        znx_normalize_middle_step_ref(base2k, 0, &mut y0, &a, &mut c0);
        znx_normalize_middle_step_avx(base2k, 0, &mut y1, &a, &mut c1);

        assert_eq!(y0, y1);
        assert_eq!(c0, c1);

        znx_normalize_middle_step_ref(basek, basek - 1, &mut y0, &a, &mut c0);
        znx_normalize_middle_step_avx(basek, basek - 1, &mut y1, &a, &mut c1);
        znx_normalize_middle_step_ref(base2k, base2k - 1, &mut y0, &a, &mut c0);
        znx_normalize_middle_step_avx(base2k, base2k - 1, &mut y1, &a, &mut c1);

        assert_eq!(y0, y1);
        assert_eq!(c0, c1);
@@ -990,16 +1098,16 @@ mod tests {
        ];
        let mut c1: [i64; 4] = c0;

        let basek = 12;
        let base2k = 12;

        znx_normalize_final_step_ref(basek, 0, &mut y0, &a, &mut c0);
        znx_normalize_final_step_avx(basek, 0, &mut y1, &a, &mut c1);
        znx_normalize_final_step_ref(base2k, 0, &mut y0, &a, &mut c0);
        znx_normalize_final_step_avx(base2k, 0, &mut y1, &a, &mut c1);

        assert_eq!(y0, y1);
        assert_eq!(c0, c1);

        znx_normalize_final_step_ref(basek, basek - 1, &mut y0, &a, &mut c0);
        znx_normalize_final_step_avx(basek, basek - 1, &mut y1, &a, &mut c1);
        znx_normalize_final_step_ref(base2k, base2k - 1, &mut y0, &a, &mut c0);
        znx_normalize_final_step_avx(base2k, base2k - 1, &mut y1, &a, &mut c1);

        assert_eq!(y0, y1);
        assert_eq!(c0, c1);
@@ -1015,4 +1123,86 @@ mod tests {
            test_znx_normalize_final_step_avx_internal();
        }
    }

    #[target_feature(enable = "avx2")]
    fn znx_extract_digit_addmul_internal() {
        let mut y0: [i64; 4] = [
            7638646372408325293,
            -61440197422348985,
            6835891051541717957,
            -4835376105455195188,
        ];
        let mut y1: [i64; 4] = y0;

        let mut c0: [i64; 4] = [
            621182201135793202,
            9000856573317006236,
            5542252755421113668,
            -6036847263131690631,
        ];
        let mut c1: [i64; 4] = c0;

        let base2k: usize = 12;

        znx_extract_digit_addmul_ref(base2k, 0, &mut y0, &mut c0);
        znx_extract_digit_addmul_avx(base2k, 0, &mut y1, &mut c1);

        assert_eq!(y0, y1);
        assert_eq!(c0, c1);

        znx_extract_digit_addmul_ref(base2k, base2k - 1, &mut y0, &mut c0);
        znx_extract_digit_addmul_avx(base2k, base2k - 1, &mut y1, &mut c1);

        assert_eq!(y0, y1);
        assert_eq!(c0, c1);
    }

    #[test]
    fn test_znx_extract_digit_addmul_avx() {
        if !std::is_x86_feature_detected!("avx2") {
            eprintln!("skipping: CPU lacks avx2");
            return;
        };
        unsafe {
            znx_extract_digit_addmul_internal();
        }
    }

    #[target_feature(enable = "avx2")]
    fn znx_normalize_digit_internal() {
        let mut y0: [i64; 4] = [
            7638646372408325293,
            -61440197422348985,
            6835891051541717957,
            -4835376105455195188,
        ];
        let mut y1: [i64; 4] = y0;

        let mut c0: [i64; 4] = [
            621182201135793202,
            9000856573317006236,
            5542252755421113668,
            -6036847263131690631,
        ];
        let mut c1: [i64; 4] = c0;

        let base2k: usize = 12;

        znx_normalize_digit_ref(base2k, &mut y0, &mut c0);
        znx_normalize_digit_avx(base2k, &mut y1, &mut c1);

        assert_eq!(y0, y1);
        assert_eq!(c0, c1);
    }

    #[test]
    fn test_znx_normalize_digit_internal_avx() {
        if !std::is_x86_feature_detected!("avx2") {
            eprintln!("skipping: CPU lacks avx2");
            return;
        };
        unsafe {
            znx_normalize_digit_internal();
        }
    }
}
@@ -41,7 +41,7 @@ pub fn znx_sub_avx(res: &mut [i64], a: &[i64], b: &[i64]) {
/// Caller must ensure the CPU supports AVX2 (e.g., via `is_x86_feature_detected!("avx2")`);
/// all inputs must have the same length and must not alias.
#[target_feature(enable = "avx2")]
pub fn znx_sub_ab_inplace_avx(res: &mut [i64], a: &[i64]) {
pub fn znx_sub_inplace_avx(res: &mut [i64], a: &[i64]) {
    #[cfg(debug_assertions)]
    {
        assert_eq!(res.len(), a.len());
@@ -67,9 +67,9 @@ pub fn znx_sub_ab_inplace_avx(res: &mut [i64], a: &[i64]) {

    // tail
    if !res.len().is_multiple_of(4) {
        use poulpy_hal::reference::znx::znx_sub_ab_inplace_ref;
        use poulpy_hal::reference::znx::znx_sub_inplace_ref;

        znx_sub_ab_inplace_ref(&mut res[span << 2..], &a[span << 2..]);
        znx_sub_inplace_ref(&mut res[span << 2..], &a[span << 2..]);
    }
}

@@ -77,7 +77,7 @@ pub fn znx_sub_ab_inplace_avx(res: &mut [i64], a: &[i64]) {
/// Caller must ensure the CPU supports AVX2 (e.g., via `is_x86_feature_detected!("avx2")`);
/// all inputs must have the same length and must not alias.
#[target_feature(enable = "avx2")]
pub fn znx_sub_ba_inplace_avx(res: &mut [i64], a: &[i64]) {
pub fn znx_sub_negate_inplace_avx(res: &mut [i64], a: &[i64]) {
    #[cfg(debug_assertions)]
    {
        assert_eq!(res.len(), a.len());
@@ -103,8 +103,8 @@ pub fn znx_sub_ba_inplace_avx(res: &mut [i64], a: &[i64]) {

    // tail
    if !res.len().is_multiple_of(4) {
        use poulpy_hal::reference::znx::znx_sub_ba_inplace_ref;
        use poulpy_hal::reference::znx::znx_sub_negate_inplace_ref;

        znx_sub_ba_inplace_ref(&mut res[span << 2..], &a[span << 2..]);
        znx_sub_negate_inplace_ref(&mut res[span << 2..], &a[span << 2..]);
    }
}