Add cross-basek normalization (#90)

* added cross_basek_normalization

* updated method signatures to take layouts

* fixed cross-base normalization

fix #91
fix #93
Author: Jean-Philippe Bossuat
Date: 2025-09-30 14:40:10 +02:00
Committed by: GitHub
Parent: 4da790ea6a
Commit: 37e13b965c
216 changed files with 12481 additions and 7745 deletions
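
For orientation: cross-base2k normalization re-expresses a decomposition in signed digits of base 2^k_in as one in base 2^k_out. The sketch below is only a conceptual scalar illustration of that idea under stated assumptions (most-significant digit first, excess high part truncated); the function name and signature are hypothetical and not this crate's API.

// Conceptual sketch only: re-digit a value from base 2^k_in to base 2^k_out,
// each output digit centered in [-2^(k_out-1), 2^(k_out-1)).
// Assumptions: digits are most-significant first; any high part that does not
// fit into `len_out` output digits is truncated.
fn redigit_base2k(k_in: u32, k_out: u32, digits_in: &[i64], len_out: usize) -> Vec<i64> {
    // Toy reconstruction in i128; the real limb-wise code propagates carries
    // instead of materializing the full integer.
    let mut v: i128 = 0;
    for &d in digits_in {
        v = (v << k_in) + d as i128;
    }
    let mut out = vec![0i64; len_out];
    for j in (0..len_out).rev() {
        let digit = ((v << (128 - k_out)) >> (128 - k_out)) as i64; // centered low digit
        out[j] = digit;
        v = (v - digit as i128) >> k_out;
    }
    out
}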


@@ -1,5 +1,6 @@
mod add;
mod automorphism;
mod mul;
mod neg;
mod normalization;
mod sub;
@@ -7,6 +8,7 @@ mod switch_ring;
pub(crate) use add::*;
pub(crate) use automorphism::*;
pub(crate) use mul::*;
pub(crate) use neg::*;
pub(crate) use normalization::*;
pub(crate) use sub::*;


@@ -0,0 +1,318 @@
/// Multiply/divide by a power of two with rounding matching [poulpy_hal::reference::znx::znx_mul_power_of_two_ref].
///
/// # Safety
/// Caller must ensure the CPU supports AVX2 (e.g., via `is_x86_feature_detected!("avx2")`);
/// all inputs must have the same length and must not alias.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
pub unsafe fn znx_mul_power_of_two_avx(k: i64, res: &mut [i64], a: &[i64]) {
#[cfg(debug_assertions)]
{
assert_eq!(res.len(), a.len());
}
use core::arch::x86_64::{
__m128i, __m256i, _mm_cvtsi32_si128, _mm256_add_epi64, _mm256_and_si256, _mm256_cmpgt_epi64, _mm256_loadu_si256,
_mm256_or_si256, _mm256_set1_epi64x, _mm256_setzero_si256, _mm256_sll_epi64, _mm256_srl_epi64, _mm256_srli_epi64,
_mm256_storeu_si256, _mm256_sub_epi64,
};
let n: usize = res.len();
if n == 0 {
return;
}
if k == 0 {
use poulpy_hal::reference::znx::znx_copy_ref;
znx_copy_ref(res, a);
return;
}
let span: usize = n >> 2; // number of 256-bit chunks
unsafe {
let mut rr: *mut __m256i = res.as_mut_ptr() as *mut __m256i;
let mut aa: *const __m256i = a.as_ptr() as *const __m256i;
if k > 0 {
// Left shift by k (variable count).
#[cfg(debug_assertions)]
{
debug_assert!(k <= 63);
}
let cnt128: __m128i = _mm_cvtsi32_si128(k as i32);
for _ in 0..span {
let x: __m256i = _mm256_loadu_si256(aa);
let y: __m256i = _mm256_sll_epi64(x, cnt128);
_mm256_storeu_si256(rr, y);
rr = rr.add(1);
aa = aa.add(1);
}
// tail
if !n.is_multiple_of(4) {
use poulpy_hal::reference::znx::znx_mul_power_of_two_ref;
znx_mul_power_of_two_ref(k, &mut res[span << 2..], &a[span << 2..]);
}
return;
}
// k < 0 => arithmetic right shift with rounding:
// for each x:
// sign_bit = (x >> 63) & 1
// bias = (1<<(kp-1)) - sign_bit
// t = x + bias
// y = t >> kp (arithmetic)
let kp = -k;
#[cfg(debug_assertions)]
{
debug_assert!((1..=63).contains(&kp));
}
let cnt_right: __m128i = _mm_cvtsi32_si128(kp as i32);
let bias_base: __m256i = _mm256_set1_epi64x(1_i64 << (kp - 1));
let top_mask: __m256i = _mm256_set1_epi64x(-1_i64 << (64 - kp)); // high kp bits
let zero: __m256i = _mm256_setzero_si256();
for _ in 0..span {
let x = _mm256_loadu_si256(aa);
// bias = (1 << (kp-1)) - sign_bit
let sign_bit_x: __m256i = _mm256_srli_epi64(x, 63);
let bias: __m256i = _mm256_sub_epi64(bias_base, sign_bit_x);
// t = x + bias
let t: __m256i = _mm256_add_epi64(x, bias);
// logical shift
let lsr: __m256i = _mm256_srl_epi64(t, cnt_right);
// sign extension
let neg: __m256i = _mm256_cmpgt_epi64(zero, t);
let fill: __m256i = _mm256_and_si256(neg, top_mask);
let y: __m256i = _mm256_or_si256(lsr, fill);
_mm256_storeu_si256(rr, y);
rr = rr.add(1);
aa = aa.add(1);
}
}
// tail
if !n.is_multiple_of(4) {
use poulpy_hal::reference::znx::znx_mul_power_of_two_ref;
znx_mul_power_of_two_ref(k, &mut res[span << 2..], &a[span << 2..]);
}
}
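As a scalar model of the rounding rule spelled out in the comments above (round to nearest, ties away from zero), assuming only what those comments state:

// Illustrative scalar equivalent of the k < 0 path above (not the crate's API).
fn div_pow2_round_scalar(kp: u32, x: i64) -> i64 {
    debug_assert!((1..=63).contains(&kp));
    let sign_bit = ((x as u64) >> 63) as i64; // 1 if x < 0, else 0
    let bias = (1i64 << (kp - 1)) - sign_bit; // 2^(kp-1), minus 1 for negative x
    x.wrapping_add(bias) >> kp                // arithmetic shift, as in the SIMD lanes
}
// e.g. div_pow2_round_scalar(1, 3) == 2 and div_pow2_round_scalar(1, -3) == -2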
/// Multiply/divide inplace by a power of two with rounding matching [poulpy_hal::reference::znx::znx_mul_power_of_two_inplace_ref].
///
/// # Safety
/// Caller must ensure the CPU supports AVX2 (e.g., via `is_x86_feature_detected!("avx2")`);
/// all inputs must have the same length and must not alias.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
pub unsafe fn znx_mul_power_of_two_inplace_avx(k: i64, res: &mut [i64]) {
use core::arch::x86_64::{
__m128i, __m256i, _mm_cvtsi32_si128, _mm256_add_epi64, _mm256_and_si256, _mm256_cmpgt_epi64, _mm256_loadu_si256,
_mm256_or_si256, _mm256_set1_epi64x, _mm256_setzero_si256, _mm256_sll_epi64, _mm256_srl_epi64, _mm256_srli_epi64,
_mm256_storeu_si256, _mm256_sub_epi64,
};
let n: usize = res.len();
if n == 0 {
return;
}
if k == 0 {
return;
}
let span: usize = n >> 2; // number of 256-bit chunks
unsafe {
let mut rr: *mut __m256i = res.as_mut_ptr() as *mut __m256i;
if k > 0 {
// Left shift by k (variable count).
#[cfg(debug_assertions)]
{
debug_assert!(k <= 63);
}
let cnt128: __m128i = _mm_cvtsi32_si128(k as i32);
for _ in 0..span {
let x: __m256i = _mm256_loadu_si256(rr);
let y: __m256i = _mm256_sll_epi64(x, cnt128);
_mm256_storeu_si256(rr, y);
rr = rr.add(1);
}
// tail
if !n.is_multiple_of(4) {
use poulpy_hal::reference::znx::znx_mul_power_of_two_inplace_ref;
znx_mul_power_of_two_inplace_ref(k, &mut res[span << 2..]);
}
return;
}
// k < 0 => arithmetic right shift with rounding:
// for each x:
// sign_bit = (x >> 63) & 1
// bias = (1<<(kp-1)) - sign_bit
// t = x + bias
// y = t >> kp (arithmetic)
let kp = -k;
#[cfg(debug_assertions)]
{
debug_assert!((1..=63).contains(&kp));
}
let cnt_right: __m128i = _mm_cvtsi32_si128(kp as i32);
let bias_base: __m256i = _mm256_set1_epi64x(1_i64 << (kp - 1));
let top_mask: __m256i = _mm256_set1_epi64x(-1_i64 << (64 - kp)); // high kp bits
let zero: __m256i = _mm256_setzero_si256();
for _ in 0..span {
let x = _mm256_loadu_si256(rr);
// bias = (1 << (kp-1)) - sign_bit
let sign_bit_x: __m256i = _mm256_srli_epi64(x, 63);
let bias: __m256i = _mm256_sub_epi64(bias_base, sign_bit_x);
// t = x + bias
let t: __m256i = _mm256_add_epi64(x, bias);
// logical shift
let lsr: __m256i = _mm256_srl_epi64(t, cnt_right);
// sign extension
let neg: __m256i = _mm256_cmpgt_epi64(zero, t);
let fill: __m256i = _mm256_and_si256(neg, top_mask);
let y: __m256i = _mm256_or_si256(lsr, fill);
_mm256_storeu_si256(rr, y);
rr = rr.add(1);
}
}
// tail
if !n.is_multiple_of(4) {
use poulpy_hal::reference::znx::znx_mul_power_of_two_inplace_ref;
znx_mul_power_of_two_inplace_ref(k, &mut res[span << 2..]);
}
}
/// Multiply/divide by a power of two and add onto the result, with rounding matching [poulpy_hal::reference::znx::znx_mul_add_power_of_two_ref].
///
/// # Safety
/// Caller must ensure the CPU supports AVX2 (e.g., via `is_x86_feature_detected!("avx2")`);
/// all inputs must have the same length and must not alias.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
pub unsafe fn znx_mul_add_power_of_two_avx(k: i64, res: &mut [i64], a: &[i64]) {
#[cfg(debug_assertions)]
{
assert_eq!(res.len(), a.len());
}
use core::arch::x86_64::{
__m128i, __m256i, _mm_cvtsi32_si128, _mm256_add_epi64, _mm256_and_si256, _mm256_cmpgt_epi64, _mm256_loadu_si256,
_mm256_or_si256, _mm256_set1_epi64x, _mm256_setzero_si256, _mm256_sll_epi64, _mm256_srl_epi64, _mm256_srli_epi64,
_mm256_storeu_si256, _mm256_sub_epi64,
};
let n: usize = res.len();
if n == 0 {
return;
}
if k == 0 {
use crate::cpu_fft64_avx::znx_avx::znx_add_inplace_avx;
znx_add_inplace_avx(res, a);
return;
}
let span: usize = n >> 2; // number of 256-bit chunks
unsafe {
let mut rr: *mut __m256i = res.as_mut_ptr() as *mut __m256i;
let mut aa: *const __m256i = a.as_ptr() as *const __m256i;
if k > 0 {
// Left shift by k (variable count).
#[cfg(debug_assertions)]
{
debug_assert!(k <= 63);
}
let cnt128: __m128i = _mm_cvtsi32_si128(k as i32);
for _ in 0..span {
let x: __m256i = _mm256_loadu_si256(aa);
let y: __m256i = _mm256_loadu_si256(rr);
_mm256_storeu_si256(rr, _mm256_add_epi64(y, _mm256_sll_epi64(x, cnt128)));
rr = rr.add(1);
aa = aa.add(1);
}
// tail
if !n.is_multiple_of(4) {
use poulpy_hal::reference::znx::znx_mul_add_power_of_two_ref;
znx_mul_add_power_of_two_ref(k, &mut res[span << 2..], &a[span << 2..]);
}
return;
}
// k < 0 => arithmetic right shift with rounding:
// for each x:
// sign_bit = (x >> 63) & 1
// bias = (1<<(kp-1)) - sign_bit
// t = x + bias
// y = t >> kp (arithmetic)
let kp = -k;
#[cfg(debug_assertions)]
{
debug_assert!((1..=63).contains(&kp));
}
let cnt_right: __m128i = _mm_cvtsi32_si128(kp as i32);
let bias_base: __m256i = _mm256_set1_epi64x(1_i64 << (kp - 1));
let top_mask: __m256i = _mm256_set1_epi64x(-1_i64 << (64 - kp)); // high kp bits
let zero: __m256i = _mm256_setzero_si256();
for _ in 0..span {
let x: __m256i = _mm256_loadu_si256(aa);
let y: __m256i = _mm256_loadu_si256(rr);
// bias = (1 << (kp-1)) - sign_bit
let sign_bit_x: __m256i = _mm256_srli_epi64(x, 63);
let bias: __m256i = _mm256_sub_epi64(bias_base, sign_bit_x);
// t = x + bias
let t: __m256i = _mm256_add_epi64(x, bias);
// logical shift
let lsr: __m256i = _mm256_srl_epi64(t, cnt_right);
// sign extension
let neg: __m256i = _mm256_cmpgt_epi64(zero, t);
let fill: __m256i = _mm256_and_si256(neg, top_mask);
let out: __m256i = _mm256_or_si256(lsr, fill);
_mm256_storeu_si256(rr, _mm256_add_epi64(y, out));
rr = rr.add(1);
aa = aa.add(1);
}
}
// tail
if !n.is_multiple_of(4) {
use poulpy_hal::reference::znx::znx_mul_add_power_of_two_ref;
znx_mul_add_power_of_two_ref(k, &mut res[span << 2..], &a[span << 2..]);
}
}
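The Safety notes above leave AVX2 feature detection to the caller; below is a minimal sketch of such a caller-side dispatch. The wrapper name is hypothetical; the two callees are the ones referenced in this file.

#[cfg(target_arch = "x86_64")]
pub fn znx_mul_power_of_two_dispatch(k: i64, res: &mut [i64], a: &[i64]) {
    assert_eq!(res.len(), a.len());
    if std::is_x86_feature_detected!("avx2") {
        // SAFETY: AVX2 support was just checked and the slice lengths match.
        unsafe { znx_mul_power_of_two_avx(k, res, a) };
    } else {
        poulpy_hal::reference::znx::znx_mul_power_of_two_ref(k, res, a);
    }
}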


@@ -6,14 +6,14 @@ use std::arch::x86_64::__m256i;
/// Caller must ensure the CPU supports AVX2 (e.g., via `is_x86_feature_detected!("avx2")`);
/// all inputs must have the same length and must not alias.
#[target_feature(enable = "avx2")]
fn normalize_consts_avx(basek: usize) -> (__m256i, __m256i, __m256i, __m256i) {
fn normalize_consts_avx(base2k: usize) -> (__m256i, __m256i, __m256i, __m256i) {
use std::arch::x86_64::_mm256_set1_epi64x;
assert!((1..=63).contains(&basek));
let mask_k: i64 = ((1u64 << basek) - 1) as i64; // 0..k-1 bits set
let sign_k: i64 = (1u64 << (basek - 1)) as i64; // bit k-1
let topmask: i64 = (!0u64 << (64 - basek)) as i64; // top k bits set
let sh_k: __m256i = _mm256_set1_epi64x(basek as i64);
assert!((1..=63).contains(&base2k));
let mask_k: i64 = ((1u64 << base2k) - 1) as i64; // 0..k-1 bits set
let sign_k: i64 = (1u64 << (base2k - 1)) as i64; // bit k-1
let topmask: i64 = (!0u64 << (64 - base2k)) as i64; // top k bits set
let sh_k: __m256i = _mm256_set1_epi64x(base2k as i64);
(
_mm256_set1_epi64x(mask_k), // mask_k_vec
_mm256_set1_epi64x(sign_k), // sign_k_vec
@@ -46,14 +46,14 @@ fn get_digit_avx(x: __m256i, mask_k: __m256i, sign_k: __m256i) -> __m256i {
unsafe fn get_carry_avx(
x: __m256i,
digit: __m256i,
basek: __m256i, // _mm256_set1_epi64x(k)
base2k: __m256i, // _mm256_set1_epi64x(k)
top_mask: __m256i, // (!0 << (64 - k)) broadcast
) -> __m256i {
use std::arch::x86_64::{
__m256i, _mm256_and_si256, _mm256_cmpgt_epi64, _mm256_or_si256, _mm256_setzero_si256, _mm256_srlv_epi64, _mm256_sub_epi64,
};
let diff: __m256i = _mm256_sub_epi64(x, digit);
let lsr: __m256i = _mm256_srlv_epi64(diff, basek); // logical >>
let lsr: __m256i = _mm256_srlv_epi64(diff, base2k); // logical >>
let neg: __m256i = _mm256_cmpgt_epi64(_mm256_setzero_si256(), diff); // 0xFFFF.. where v<0
let fill: __m256i = _mm256_and_si256(neg, top_mask); // top k bits if negative
_mm256_or_si256(lsr, fill)
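The two helpers above implement the digit/carry split used throughout this file. A scalar sketch of the same split, illustrative only (compare get_digit_i64 / get_carry_i64 in poulpy_hal::reference::znx, used by the tests below):

// Split x into a centered base-2^k digit and its exact carry (sketch).
fn digit_and_carry(base2k: usize, x: i64) -> (i64, i64) {
    debug_assert!((1..=63).contains(&base2k));
    // digit: low base2k bits of x, sign-extended into [-2^(base2k-1), 2^(base2k-1))
    let digit = (x << (64 - base2k)) >> (64 - base2k);
    // carry: the exact remaining high part, (x - digit) / 2^base2k
    let carry = x.wrapping_sub(digit) >> base2k;
    (digit, carry)
}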
@@ -61,13 +61,121 @@ unsafe fn get_carry_avx(
/// # Safety
/// Caller must ensure the CPU supports AVX2 (e.g., via `is_x86_feature_detected!("avx2")`);
/// all inputs must have the same length and must not alias.
/// `res` and `src` must have the same length and must not alias.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
pub fn znx_normalize_first_step_carry_only_avx(basek: usize, lsh: usize, x: &[i64], carry: &mut [i64]) {
pub fn znx_extract_digit_addmul_avx(base2k: usize, lsh: usize, res: &mut [i64], src: &mut [i64]) {
#[cfg(debug_assertions)]
{
assert_eq!(x.len(), carry.len());
assert!(lsh < basek);
assert_eq!(res.len(), src.len());
assert!(lsh < base2k);
}
use std::arch::x86_64::{
__m256i, _mm256_add_epi64, _mm256_loadu_si256, _mm256_set1_epi64x, _mm256_sllv_epi64, _mm256_storeu_si256,
};
let n: usize = res.len();
let span: usize = n >> 2;
unsafe {
let mut rr: *mut __m256i = res.as_mut_ptr() as *mut __m256i;
let mut ss: *mut __m256i = src.as_mut_ptr() as *mut __m256i;
// constants for digit/carry extraction
let (mask, sign, basek_vec, top_mask) = normalize_consts_avx(base2k);
let lsh_v: __m256i = _mm256_set1_epi64x(lsh as i64);
for _ in 0..span {
// load source & extract digit/carry
let sv: __m256i = _mm256_loadu_si256(ss);
let digit_256: __m256i = get_digit_avx(sv, mask, sign);
let carry_256: __m256i = get_carry_avx(sv, digit_256, basek_vec, top_mask);
// res += (digit << lsh)
let rv: __m256i = _mm256_loadu_si256(rr);
let madd: __m256i = _mm256_sllv_epi64(digit_256, lsh_v);
let sum: __m256i = _mm256_add_epi64(rv, madd);
_mm256_storeu_si256(rr, sum);
_mm256_storeu_si256(ss, carry_256);
rr = rr.add(1);
ss = ss.add(1);
}
}
// tail (scalar)
if !n.is_multiple_of(4) {
use poulpy_hal::reference::znx::znx_extract_digit_addmul_ref;
let off: usize = span << 2;
znx_extract_digit_addmul_ref(base2k, lsh, &mut res[off..], &mut src[off..]);
}
}
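A scalar model of the loop above, with wrapping arithmetic to mirror the SIMD lanes (illustrative sketch, not the reference implementation):

fn extract_digit_addmul_scalar(base2k: usize, lsh: usize, res: &mut [i64], src: &mut [i64]) {
    assert_eq!(res.len(), src.len());
    assert!(lsh < base2k);
    for i in 0..res.len() {
        let digit = (src[i] << (64 - base2k)) >> (64 - base2k); // centered digit of src[i]
        let carry = src[i].wrapping_sub(digit) >> base2k;       // exact high part
        res[i] = res[i].wrapping_add(digit << lsh);             // res += digit << lsh
        src[i] = carry;                                         // src now holds the carry
    }
}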
/// # Safety
/// Caller must ensure the CPU supports AVX2 (e.g., via `is_x86_feature_detected!("avx2")`);
/// `res` and `src` must have the same length and must not alias.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
pub fn znx_normalize_digit_avx(base2k: usize, res: &mut [i64], src: &mut [i64]) {
#[cfg(debug_assertions)]
{
assert_eq!(res.len(), src.len());
}
use std::arch::x86_64::{__m256i, _mm256_add_epi64, _mm256_loadu_si256, _mm256_storeu_si256};
let n: usize = res.len();
let span: usize = n >> 2;
unsafe {
// Pointers to 256-bit lanes
let mut rr: *mut __m256i = res.as_mut_ptr() as *mut __m256i;
let mut ss: *mut __m256i = src.as_mut_ptr() as *mut __m256i;
// Constants for digit/carry extraction
let (mask, sign, basek_vec, top_mask) = normalize_consts_avx(base2k);
for _ in 0..span {
// Load res lane
let rv: __m256i = _mm256_loadu_si256(rr);
// Extract digit and carry from res
let digit_256: __m256i = get_digit_avx(rv, mask, sign);
let carry_256: __m256i = get_carry_avx(rv, digit_256, basek_vec, top_mask);
// src += carry
let sv: __m256i = _mm256_loadu_si256(ss);
let sum: __m256i = _mm256_add_epi64(sv, carry_256);
_mm256_storeu_si256(ss, sum);
_mm256_storeu_si256(rr, digit_256);
rr = rr.add(1);
ss = ss.add(1);
}
}
// scalar tail
if !n.is_multiple_of(4) {
use poulpy_hal::reference::znx::znx_normalize_digit_ref;
let off = span << 2;
znx_normalize_digit_ref(base2k, &mut res[off..], &mut src[off..]);
}
}
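And the scalar counterpart of znx_normalize_digit_avx: res keeps its centered digit while the carry is accumulated into src (sketch only):

fn normalize_digit_scalar(base2k: usize, res: &mut [i64], src: &mut [i64]) {
    assert_eq!(res.len(), src.len());
    for i in 0..res.len() {
        let digit = (res[i] << (64 - base2k)) >> (64 - base2k);
        let carry = res[i].wrapping_sub(digit) >> base2k;
        src[i] = src[i].wrapping_add(carry); // src += carry
        res[i] = digit;                      // res becomes its digit
    }
}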
/// # Safety
/// Caller must ensure the CPU supports AVX2 (e.g., via `is_x86_feature_detected!("avx2")`);
/// all inputs must have the same length and must not alias.
#[target_feature(enable = "avx2")]
pub fn znx_normalize_first_step_carry_only_avx(base2k: usize, lsh: usize, x: &[i64], carry: &mut [i64]) {
#[cfg(debug_assertions)]
{
assert!(x.len() <= carry.len());
assert!(lsh < base2k);
}
use std::arch::x86_64::{_mm256_loadu_si256, _mm256_storeu_si256};
@@ -81,19 +189,19 @@ pub fn znx_normalize_first_step_carry_only_avx(basek: usize, lsh: usize, x: &[i6
let mut cc: *mut __m256i = carry.as_ptr() as *mut __m256i;
let (mask, sign, basek_vec, top_mask) = if lsh == 0 {
normalize_consts_avx(basek)
normalize_consts_avx(base2k)
} else {
normalize_consts_avx(basek - lsh)
normalize_consts_avx(base2k - lsh)
};
for _ in 0..span {
let xx_256: __m256i = _mm256_loadu_si256(xx);
let xv: __m256i = _mm256_loadu_si256(xx);
// (x << (64 - basek)) >> (64 - basek)
let digit_256: __m256i = get_digit_avx(xx_256, mask, sign);
// (x << (64 - base2k)) >> (64 - base2k)
let digit_256: __m256i = get_digit_avx(xv, mask, sign);
// (x - digit) >> basek
let carry_256: __m256i = get_carry_avx(xx_256, digit_256, basek_vec, top_mask);
// (x - digit) >> base2k
let carry_256: __m256i = get_carry_avx(xv, digit_256, basek_vec, top_mask);
_mm256_storeu_si256(cc, carry_256);
@@ -106,7 +214,7 @@ pub fn znx_normalize_first_step_carry_only_avx(basek: usize, lsh: usize, x: &[i6
if !x.len().is_multiple_of(4) {
use poulpy_hal::reference::znx::znx_normalize_first_step_carry_only_ref;
znx_normalize_first_step_carry_only_ref(basek, lsh, &x[span << 2..], &mut carry[span << 2..]);
znx_normalize_first_step_carry_only_ref(base2k, lsh, &x[span << 2..], &mut carry[span << 2..]);
}
}
@@ -114,11 +222,11 @@ pub fn znx_normalize_first_step_carry_only_avx(basek: usize, lsh: usize, x: &[i6
/// Caller must ensure the CPU supports AVX2 (e.g., via `is_x86_feature_detected!("avx2")`);
/// all inputs must have the same length and must not alias.
#[target_feature(enable = "avx2")]
pub fn znx_normalize_first_step_inplace_avx(basek: usize, lsh: usize, x: &mut [i64], carry: &mut [i64]) {
pub fn znx_normalize_first_step_inplace_avx(base2k: usize, lsh: usize, x: &mut [i64], carry: &mut [i64]) {
#[cfg(debug_assertions)]
{
assert_eq!(x.len(), carry.len());
assert!(lsh < basek);
assert!(x.len() <= carry.len());
assert!(lsh < base2k);
}
use std::arch::x86_64::{_mm256_loadu_si256, _mm256_set1_epi64x, _mm256_sllv_epi64, _mm256_storeu_si256};
@@ -132,16 +240,16 @@ pub fn znx_normalize_first_step_inplace_avx(basek: usize, lsh: usize, x: &mut [i
let mut cc: *mut __m256i = carry.as_ptr() as *mut __m256i;
if lsh == 0 {
let (mask, sign, basek_vec, top_mask) = normalize_consts_avx(basek);
let (mask, sign, basek_vec, top_mask) = normalize_consts_avx(base2k);
for _ in 0..span {
let xx_256: __m256i = _mm256_loadu_si256(xx);
let xv: __m256i = _mm256_loadu_si256(xx);
// (x << (64 - basek)) >> (64 - basek)
let digit_256: __m256i = get_digit_avx(xx_256, mask, sign);
// (x << (64 - base2k)) >> (64 - base2k)
let digit_256: __m256i = get_digit_avx(xv, mask, sign);
// (x - digit) >> basek
let carry_256: __m256i = get_carry_avx(xx_256, digit_256, basek_vec, top_mask);
// (x - digit) >> base2k
let carry_256: __m256i = get_carry_avx(xv, digit_256, basek_vec, top_mask);
_mm256_storeu_si256(xx, digit_256);
_mm256_storeu_si256(cc, carry_256);
@@ -150,18 +258,18 @@ pub fn znx_normalize_first_step_inplace_avx(basek: usize, lsh: usize, x: &mut [i
cc = cc.add(1);
}
} else {
let (mask, sign, basek_vec, top_mask) = normalize_consts_avx(basek - lsh);
let (mask, sign, basek_vec, top_mask) = normalize_consts_avx(base2k - lsh);
let lsh_v: __m256i = _mm256_set1_epi64x(lsh as i64);
for _ in 0..span {
let xx_256: __m256i = _mm256_loadu_si256(xx);
let xv: __m256i = _mm256_loadu_si256(xx);
// (x << (64 - basek)) >> (64 - basek)
let digit_256: __m256i = get_digit_avx(xx_256, mask, sign);
// (x << (64 - base2k)) >> (64 - base2k)
let digit_256: __m256i = get_digit_avx(xv, mask, sign);
// (x - digit) >> basek
let carry_256: __m256i = get_carry_avx(xx_256, digit_256, basek_vec, top_mask);
// (x - digit) >> base2k
let carry_256: __m256i = get_carry_avx(xv, digit_256, basek_vec, top_mask);
_mm256_storeu_si256(xx, _mm256_sllv_epi64(digit_256, lsh_v));
_mm256_storeu_si256(cc, carry_256);
@@ -176,7 +284,7 @@ pub fn znx_normalize_first_step_inplace_avx(basek: usize, lsh: usize, x: &mut [i
if !x.len().is_multiple_of(4) {
use poulpy_hal::reference::znx::znx_normalize_first_step_inplace_ref;
znx_normalize_first_step_inplace_ref(basek, lsh, &mut x[span << 2..], &mut carry[span << 2..]);
znx_normalize_first_step_inplace_ref(base2k, lsh, &mut x[span << 2..], &mut carry[span << 2..]);
}
}
@@ -184,12 +292,12 @@ pub fn znx_normalize_first_step_inplace_avx(basek: usize, lsh: usize, x: &mut [i
/// Caller must ensure the CPU supports AVX2 (e.g., via `is_x86_feature_detected!("avx2")`);
/// all inputs must have the same length and must not alias.
#[target_feature(enable = "avx2")]
pub fn znx_normalize_first_step_avx(basek: usize, lsh: usize, x: &mut [i64], a: &[i64], carry: &mut [i64]) {
pub fn znx_normalize_first_step_avx(base2k: usize, lsh: usize, x: &mut [i64], a: &[i64], carry: &mut [i64]) {
#[cfg(debug_assertions)]
{
assert_eq!(x.len(), carry.len());
assert_eq!(a.len(), carry.len());
assert!(lsh < basek);
assert_eq!(x.len(), a.len());
assert!(x.len() <= carry.len());
assert!(lsh < base2k);
}
use std::arch::x86_64::{_mm256_loadu_si256, _mm256_sllv_epi64, _mm256_storeu_si256};
@@ -204,16 +312,16 @@ pub fn znx_normalize_first_step_avx(basek: usize, lsh: usize, x: &mut [i64], a:
let mut cc: *mut __m256i = carry.as_ptr() as *mut __m256i;
if lsh == 0 {
let (mask, sign, basek_vec, top_mask) = normalize_consts_avx(basek);
let (mask, sign, basek_vec, top_mask) = normalize_consts_avx(base2k);
for _ in 0..span {
let aa_256: __m256i = _mm256_loadu_si256(aa);
let av: __m256i = _mm256_loadu_si256(aa);
// (x << (64 - basek)) >> (64 - basek)
let digit_256: __m256i = get_digit_avx(aa_256, mask, sign);
// (x << (64 - base2k)) >> (64 - base2k)
let digit_256: __m256i = get_digit_avx(av, mask, sign);
// (x - digit) >> basek
let carry_256: __m256i = get_carry_avx(aa_256, digit_256, basek_vec, top_mask);
// (x - digit) >> base2k
let carry_256: __m256i = get_carry_avx(av, digit_256, basek_vec, top_mask);
_mm256_storeu_si256(xx, digit_256);
_mm256_storeu_si256(cc, carry_256);
@@ -225,18 +333,18 @@ pub fn znx_normalize_first_step_avx(basek: usize, lsh: usize, x: &mut [i64], a:
} else {
use std::arch::x86_64::_mm256_set1_epi64x;
let (mask, sign, basek_vec, top_mask) = normalize_consts_avx(basek - lsh);
let (mask, sign, basek_vec, top_mask) = normalize_consts_avx(base2k - lsh);
let lsh_v: __m256i = _mm256_set1_epi64x(lsh as i64);
for _ in 0..span {
let aa_256: __m256i = _mm256_loadu_si256(aa);
let av: __m256i = _mm256_loadu_si256(aa);
// (x << (64 - basek)) >> (64 - basek)
let digit_256: __m256i = get_digit_avx(aa_256, mask, sign);
// (x << (64 - base2k)) >> (64 - base2k)
let digit_256: __m256i = get_digit_avx(av, mask, sign);
// (x - digit) >> basek
let carry_256: __m256i = get_carry_avx(aa_256, digit_256, basek_vec, top_mask);
// (x - digit) >> base2k
let carry_256: __m256i = get_carry_avx(av, digit_256, basek_vec, top_mask);
_mm256_storeu_si256(xx, _mm256_sllv_epi64(digit_256, lsh_v));
_mm256_storeu_si256(cc, carry_256);
@@ -253,7 +361,7 @@ pub fn znx_normalize_first_step_avx(basek: usize, lsh: usize, x: &mut [i64], a:
use poulpy_hal::reference::znx::znx_normalize_first_step_ref;
znx_normalize_first_step_ref(
basek,
base2k,
lsh,
&mut x[span << 2..],
&a[span << 2..],
@@ -266,11 +374,11 @@ pub fn znx_normalize_first_step_avx(basek: usize, lsh: usize, x: &mut [i64], a:
/// Caller must ensure the CPU supports AVX2 (e.g., via `is_x86_feature_detected!("avx2")`);
/// all inputs must have the same length and must not alias.
#[target_feature(enable = "avx2")]
pub fn znx_normalize_middle_step_inplace_avx(basek: usize, lsh: usize, x: &mut [i64], carry: &mut [i64]) {
pub fn znx_normalize_middle_step_inplace_avx(base2k: usize, lsh: usize, x: &mut [i64], carry: &mut [i64]) {
#[cfg(debug_assertions)]
{
assert_eq!(x.len(), carry.len());
assert!(lsh < basek);
assert!(x.len() <= carry.len());
assert!(lsh < base2k);
}
use std::arch::x86_64::{_mm256_add_epi64, _mm256_loadu_si256, _mm256_sllv_epi64, _mm256_storeu_si256};
@@ -279,7 +387,7 @@ pub fn znx_normalize_middle_step_inplace_avx(basek: usize, lsh: usize, x: &mut [
let span: usize = n >> 2;
let (mask, sign, basek_vec, top_mask) = normalize_consts_avx(basek);
let (mask, sign, basek_vec, top_mask) = normalize_consts_avx(base2k);
unsafe {
let mut xx: *mut __m256i = x.as_mut_ptr() as *mut __m256i;
@@ -287,13 +395,13 @@ pub fn znx_normalize_middle_step_inplace_avx(basek: usize, lsh: usize, x: &mut [
if lsh == 0 {
for _ in 0..span {
let xx_256: __m256i = _mm256_loadu_si256(xx);
let cc_256: __m256i = _mm256_loadu_si256(cc);
let xv: __m256i = _mm256_loadu_si256(xx);
let cv: __m256i = _mm256_loadu_si256(cc);
let d0: __m256i = get_digit_avx(xx_256, mask, sign);
let c0: __m256i = get_carry_avx(xx_256, d0, basek_vec, top_mask);
let d0: __m256i = get_digit_avx(xv, mask, sign);
let c0: __m256i = get_carry_avx(xv, d0, basek_vec, top_mask);
let s: __m256i = _mm256_add_epi64(d0, cc_256);
let s: __m256i = _mm256_add_epi64(d0, cv);
let x1: __m256i = get_digit_avx(s, mask, sign);
let c1: __m256i = get_carry_avx(s, x1, basek_vec, top_mask);
let cout: __m256i = _mm256_add_epi64(c0, c1);
@@ -307,20 +415,20 @@ pub fn znx_normalize_middle_step_inplace_avx(basek: usize, lsh: usize, x: &mut [
} else {
use std::arch::x86_64::_mm256_set1_epi64x;
let (mask_lsh, sign_lsh, basek_vec_lsh, top_mask_lsh) = normalize_consts_avx(basek - lsh);
let (mask_lsh, sign_lsh, basek_vec_lsh, top_mask_lsh) = normalize_consts_avx(base2k - lsh);
let lsh_v: __m256i = _mm256_set1_epi64x(lsh as i64);
for _ in 0..span {
let xx_256: __m256i = _mm256_loadu_si256(xx);
let cc_256: __m256i = _mm256_loadu_si256(cc);
let xv: __m256i = _mm256_loadu_si256(xx);
let cv: __m256i = _mm256_loadu_si256(cc);
let d0: __m256i = get_digit_avx(xx_256, mask_lsh, sign_lsh);
let c0: __m256i = get_carry_avx(xx_256, d0, basek_vec_lsh, top_mask_lsh);
let d0: __m256i = get_digit_avx(xv, mask_lsh, sign_lsh);
let c0: __m256i = get_carry_avx(xv, d0, basek_vec_lsh, top_mask_lsh);
let d0_lsh: __m256i = _mm256_sllv_epi64(d0, lsh_v);
let s: __m256i = _mm256_add_epi64(d0_lsh, cc_256);
let s: __m256i = _mm256_add_epi64(d0_lsh, cv);
let x1: __m256i = get_digit_avx(s, mask, sign);
let c1: __m256i = get_carry_avx(s, x1, basek_vec, top_mask);
let cout: __m256i = _mm256_add_epi64(c0, c1);
@@ -337,7 +445,7 @@ pub fn znx_normalize_middle_step_inplace_avx(basek: usize, lsh: usize, x: &mut [
if !x.len().is_multiple_of(4) {
use poulpy_hal::reference::znx::znx_normalize_middle_step_inplace_ref;
znx_normalize_middle_step_inplace_ref(basek, lsh, &mut x[span << 2..], &mut carry[span << 2..]);
znx_normalize_middle_step_inplace_ref(base2k, lsh, &mut x[span << 2..], &mut carry[span << 2..]);
}
}
@@ -345,11 +453,11 @@ pub fn znx_normalize_middle_step_inplace_avx(basek: usize, lsh: usize, x: &mut [
/// Caller must ensure the CPU supports AVX2 (e.g., via `is_x86_feature_detected!("avx2")`);
/// all inputs must have the same length and must not alias.
#[target_feature(enable = "avx2")]
pub fn znx_normalize_middle_step_carry_only_avx(basek: usize, lsh: usize, x: &[i64], carry: &mut [i64]) {
pub fn znx_normalize_middle_step_carry_only_avx(base2k: usize, lsh: usize, x: &[i64], carry: &mut [i64]) {
#[cfg(debug_assertions)]
{
assert_eq!(x.len(), carry.len());
assert!(lsh < basek);
assert!(x.len() <= carry.len());
assert!(lsh < base2k);
}
use std::arch::x86_64::{_mm256_add_epi64, _mm256_loadu_si256, _mm256_sllv_epi64, _mm256_storeu_si256};
@@ -358,7 +466,7 @@ pub fn znx_normalize_middle_step_carry_only_avx(basek: usize, lsh: usize, x: &[i
let span: usize = n >> 2;
let (mask, sign, basek_vec, top_mask) = normalize_consts_avx(basek);
let (mask, sign, basek_vec, top_mask) = normalize_consts_avx(base2k);
unsafe {
let mut xx: *const __m256i = x.as_ptr() as *const __m256i;
@@ -366,13 +474,13 @@ pub fn znx_normalize_middle_step_carry_only_avx(basek: usize, lsh: usize, x: &[i
if lsh == 0 {
for _ in 0..span {
let xx_256: __m256i = _mm256_loadu_si256(xx);
let cc_256: __m256i = _mm256_loadu_si256(cc);
let xv: __m256i = _mm256_loadu_si256(xx);
let cv: __m256i = _mm256_loadu_si256(cc);
let d0: __m256i = get_digit_avx(xx_256, mask, sign);
let c0: __m256i = get_carry_avx(xx_256, d0, basek_vec, top_mask);
let d0: __m256i = get_digit_avx(xv, mask, sign);
let c0: __m256i = get_carry_avx(xv, d0, basek_vec, top_mask);
let s: __m256i = _mm256_add_epi64(d0, cc_256);
let s: __m256i = _mm256_add_epi64(d0, cv);
let x1: __m256i = get_digit_avx(s, mask, sign);
let c1: __m256i = get_carry_avx(s, x1, basek_vec, top_mask);
let cout: __m256i = _mm256_add_epi64(c0, c1);
@@ -385,20 +493,20 @@ pub fn znx_normalize_middle_step_carry_only_avx(basek: usize, lsh: usize, x: &[i
} else {
use std::arch::x86_64::_mm256_set1_epi64x;
let (mask_lsh, sign_lsh, basek_vec_lsh, top_mask_lsh) = normalize_consts_avx(basek - lsh);
let (mask_lsh, sign_lsh, basek_vec_lsh, top_mask_lsh) = normalize_consts_avx(base2k - lsh);
let lsh_v: __m256i = _mm256_set1_epi64x(lsh as i64);
for _ in 0..span {
let xx_256: __m256i = _mm256_loadu_si256(xx);
let cc_256: __m256i = _mm256_loadu_si256(cc);
let xv: __m256i = _mm256_loadu_si256(xx);
let cv: __m256i = _mm256_loadu_si256(cc);
let d0: __m256i = get_digit_avx(xx_256, mask_lsh, sign_lsh);
let c0: __m256i = get_carry_avx(xx_256, d0, basek_vec_lsh, top_mask_lsh);
let d0: __m256i = get_digit_avx(xv, mask_lsh, sign_lsh);
let c0: __m256i = get_carry_avx(xv, d0, basek_vec_lsh, top_mask_lsh);
let d0_lsh: __m256i = _mm256_sllv_epi64(d0, lsh_v);
let s: __m256i = _mm256_add_epi64(d0_lsh, cc_256);
let s: __m256i = _mm256_add_epi64(d0_lsh, cv);
let x1: __m256i = get_digit_avx(s, mask, sign);
let c1: __m256i = get_carry_avx(s, x1, basek_vec, top_mask);
let cout: __m256i = _mm256_add_epi64(c0, c1);
@@ -414,7 +522,7 @@ pub fn znx_normalize_middle_step_carry_only_avx(basek: usize, lsh: usize, x: &[i
if !x.len().is_multiple_of(4) {
use poulpy_hal::reference::znx::znx_normalize_middle_step_carry_only_ref;
znx_normalize_middle_step_carry_only_ref(basek, lsh, &x[span << 2..], &mut carry[span << 2..]);
znx_normalize_middle_step_carry_only_ref(base2k, lsh, &x[span << 2..], &mut carry[span << 2..]);
}
}
@@ -422,12 +530,12 @@ pub fn znx_normalize_middle_step_carry_only_avx(basek: usize, lsh: usize, x: &[i
/// Caller must ensure the CPU supports AVX2 (e.g., via `is_x86_feature_detected!("avx2")`);
/// all inputs must have the same length and must not alias.
#[target_feature(enable = "avx2")]
pub fn znx_normalize_middle_step_avx(basek: usize, lsh: usize, x: &mut [i64], a: &[i64], carry: &mut [i64]) {
pub fn znx_normalize_middle_step_avx(base2k: usize, lsh: usize, x: &mut [i64], a: &[i64], carry: &mut [i64]) {
#[cfg(debug_assertions)]
{
assert_eq!(x.len(), carry.len());
assert_eq!(a.len(), carry.len());
assert!(lsh < basek);
assert_eq!(x.len(), a.len());
assert!(x.len() <= carry.len());
assert!(lsh < base2k);
}
use std::arch::x86_64::{_mm256_add_epi64, _mm256_loadu_si256, _mm256_sllv_epi64, _mm256_storeu_si256};
@@ -436,7 +544,7 @@ pub fn znx_normalize_middle_step_avx(basek: usize, lsh: usize, x: &mut [i64], a:
let span: usize = n >> 2;
let (mask, sign, basek_vec, top_mask) = normalize_consts_avx(basek);
let (mask, sign, basek_vec, top_mask) = normalize_consts_avx(base2k);
unsafe {
let mut xx: *mut __m256i = x.as_mut_ptr() as *mut __m256i;
@@ -445,13 +553,13 @@ pub fn znx_normalize_middle_step_avx(basek: usize, lsh: usize, x: &mut [i64], a:
if lsh == 0 {
for _ in 0..span {
let aa_256: __m256i = _mm256_loadu_si256(aa);
let cc_256: __m256i = _mm256_loadu_si256(cc);
let av: __m256i = _mm256_loadu_si256(aa);
let cv: __m256i = _mm256_loadu_si256(cc);
let d0: __m256i = get_digit_avx(aa_256, mask, sign);
let c0: __m256i = get_carry_avx(aa_256, d0, basek_vec, top_mask);
let d0: __m256i = get_digit_avx(av, mask, sign);
let c0: __m256i = get_carry_avx(av, d0, basek_vec, top_mask);
let s: __m256i = _mm256_add_epi64(d0, cc_256);
let s: __m256i = _mm256_add_epi64(d0, cv);
let x1: __m256i = get_digit_avx(s, mask, sign);
let c1: __m256i = get_carry_avx(s, x1, basek_vec, top_mask);
let cout: __m256i = _mm256_add_epi64(c0, c1);
@@ -466,20 +574,20 @@ pub fn znx_normalize_middle_step_avx(basek: usize, lsh: usize, x: &mut [i64], a:
} else {
use std::arch::x86_64::_mm256_set1_epi64x;
let (mask_lsh, sign_lsh, basek_vec_lsh, top_mask_lsh) = normalize_consts_avx(basek - lsh);
let (mask_lsh, sign_lsh, basek_vec_lsh, top_mask_lsh) = normalize_consts_avx(base2k - lsh);
let lsh_v: __m256i = _mm256_set1_epi64x(lsh as i64);
for _ in 0..span {
let aa_256: __m256i = _mm256_loadu_si256(aa);
let cc_256: __m256i = _mm256_loadu_si256(cc);
let av: __m256i = _mm256_loadu_si256(aa);
let cv: __m256i = _mm256_loadu_si256(cc);
let d0: __m256i = get_digit_avx(aa_256, mask_lsh, sign_lsh);
let c0: __m256i = get_carry_avx(aa_256, d0, basek_vec_lsh, top_mask_lsh);
let d0: __m256i = get_digit_avx(av, mask_lsh, sign_lsh);
let c0: __m256i = get_carry_avx(av, d0, basek_vec_lsh, top_mask_lsh);
let d0_lsh: __m256i = _mm256_sllv_epi64(d0, lsh_v);
let s: __m256i = _mm256_add_epi64(d0_lsh, cc_256);
let s: __m256i = _mm256_add_epi64(d0_lsh, cv);
let x1: __m256i = get_digit_avx(s, mask, sign);
let c1: __m256i = get_carry_avx(s, x1, basek_vec, top_mask);
let cout: __m256i = _mm256_add_epi64(c0, c1);
@@ -498,7 +606,7 @@ pub fn znx_normalize_middle_step_avx(basek: usize, lsh: usize, x: &mut [i64], a:
use poulpy_hal::reference::znx::znx_normalize_middle_step_ref;
znx_normalize_middle_step_ref(
basek,
base2k,
lsh,
&mut x[span << 2..],
&a[span << 2..],
@@ -511,11 +619,11 @@ pub fn znx_normalize_middle_step_avx(basek: usize, lsh: usize, x: &mut [i64], a:
/// Caller must ensure the CPU supports AVX2 (e.g., via `is_x86_feature_detected!("avx2")`);
/// all inputs must have the same length and must not alias.
#[target_feature(enable = "avx2")]
pub fn znx_normalize_final_step_inplace_avx(basek: usize, lsh: usize, x: &mut [i64], carry: &mut [i64]) {
pub fn znx_normalize_final_step_inplace_avx(base2k: usize, lsh: usize, x: &mut [i64], carry: &mut [i64]) {
#[cfg(debug_assertions)]
{
assert_eq!(x.len(), carry.len());
assert!(lsh < basek);
assert!(x.len() <= carry.len());
assert!(lsh < base2k);
}
use std::arch::x86_64::{_mm256_add_epi64, _mm256_loadu_si256, _mm256_sllv_epi64, _mm256_storeu_si256};
@@ -524,7 +632,7 @@ pub fn znx_normalize_final_step_inplace_avx(basek: usize, lsh: usize, x: &mut [i
let span: usize = n >> 2;
let (mask, sign, _, _) = normalize_consts_avx(basek);
let (mask, sign, _, _) = normalize_consts_avx(base2k);
unsafe {
let mut xx: *mut __m256i = x.as_mut_ptr() as *mut __m256i;
@@ -547,7 +655,7 @@ pub fn znx_normalize_final_step_inplace_avx(basek: usize, lsh: usize, x: &mut [i
} else {
use std::arch::x86_64::_mm256_set1_epi64x;
let (mask_lsh, sign_lsh, _, _) = normalize_consts_avx(basek - lsh);
let (mask_lsh, sign_lsh, _, _) = normalize_consts_avx(base2k - lsh);
let lsh_v: __m256i = _mm256_set1_epi64x(lsh as i64);
@@ -573,7 +681,7 @@ pub fn znx_normalize_final_step_inplace_avx(basek: usize, lsh: usize, x: &mut [i
if !x.len().is_multiple_of(4) {
use poulpy_hal::reference::znx::znx_normalize_final_step_inplace_ref;
znx_normalize_final_step_inplace_ref(basek, lsh, &mut x[span << 2..], &mut carry[span << 2..]);
znx_normalize_final_step_inplace_ref(base2k, lsh, &mut x[span << 2..], &mut carry[span << 2..]);
}
}
@@ -581,12 +689,12 @@ pub fn znx_normalize_final_step_inplace_avx(basek: usize, lsh: usize, x: &mut [i
/// Caller must ensure the CPU supports AVX2 (e.g., via `is_x86_feature_detected!("avx2")`);
/// all inputs must have the same length and must not alias.
#[target_feature(enable = "avx2")]
pub fn znx_normalize_final_step_avx(basek: usize, lsh: usize, x: &mut [i64], a: &[i64], carry: &mut [i64]) {
pub fn znx_normalize_final_step_avx(base2k: usize, lsh: usize, x: &mut [i64], a: &[i64], carry: &mut [i64]) {
#[cfg(debug_assertions)]
{
assert_eq!(x.len(), carry.len());
assert_eq!(a.len(), carry.len());
assert!(lsh < basek);
assert_eq!(x.len(), a.len());
assert!(x.len() <= carry.len());
assert!(lsh < base2k);
}
use std::arch::x86_64::{_mm256_add_epi64, _mm256_loadu_si256, _mm256_sllv_epi64, _mm256_storeu_si256};
@@ -595,7 +703,7 @@ pub fn znx_normalize_final_step_avx(basek: usize, lsh: usize, x: &mut [i64], a:
let span: usize = n >> 2;
let (mask, sign, _, _) = normalize_consts_avx(basek);
let (mask, sign, _, _) = normalize_consts_avx(base2k);
unsafe {
let mut xx: *mut __m256i = x.as_mut_ptr() as *mut __m256i;
@@ -620,7 +728,7 @@ pub fn znx_normalize_final_step_avx(basek: usize, lsh: usize, x: &mut [i64], a:
} else {
use std::arch::x86_64::_mm256_set1_epi64x;
let (mask_lsh, sign_lsh, _, _) = normalize_consts_avx(basek - lsh);
let (mask_lsh, sign_lsh, _, _) = normalize_consts_avx(base2k - lsh);
let lsh_v: __m256i = _mm256_set1_epi64x(lsh as i64);
@@ -647,7 +755,7 @@ pub fn znx_normalize_final_step_avx(basek: usize, lsh: usize, x: &mut [i64], a:
use poulpy_hal::reference::znx::znx_normalize_final_step_ref;
znx_normalize_final_step_ref(
basek,
base2k,
lsh,
&mut x[span << 2..],
&a[span << 2..],
@@ -658,9 +766,9 @@ pub fn znx_normalize_final_step_avx(basek: usize, lsh: usize, x: &mut [i64], a:
mod tests {
use poulpy_hal::reference::znx::{
get_carry, get_digit, znx_normalize_final_step_inplace_ref, znx_normalize_final_step_ref,
znx_normalize_first_step_inplace_ref, znx_normalize_first_step_ref, znx_normalize_middle_step_inplace_ref,
znx_normalize_middle_step_ref,
get_carry_i64, get_digit_i64, znx_extract_digit_addmul_ref, znx_normalize_digit_ref,
znx_normalize_final_step_inplace_ref, znx_normalize_final_step_ref, znx_normalize_first_step_inplace_ref,
znx_normalize_first_step_ref, znx_normalize_middle_step_inplace_ref, znx_normalize_middle_step_ref,
};
use super::*;
@@ -670,7 +778,7 @@ mod tests {
#[allow(dead_code)]
#[target_feature(enable = "avx2")]
fn test_get_digit_avx_internal() {
let basek: usize = 12;
let base2k: usize = 12;
let x: [i64; 4] = [
7638646372408325293,
-61440197422348985,
@@ -678,15 +786,15 @@ mod tests {
-4835376105455195188,
];
let y0: Vec<i64> = vec![
get_digit(basek, x[0]),
get_digit(basek, x[1]),
get_digit(basek, x[2]),
get_digit(basek, x[3]),
get_digit_i64(base2k, x[0]),
get_digit_i64(base2k, x[1]),
get_digit_i64(base2k, x[2]),
get_digit_i64(base2k, x[3]),
];
let mut y1: Vec<i64> = vec![0i64; 4];
unsafe {
let x_256: __m256i = _mm256_loadu_si256(x.as_ptr() as *const __m256i);
let (mask, sign, _, _) = normalize_consts_avx(basek);
let (mask, sign, _, _) = normalize_consts_avx(base2k);
let digit: __m256i = get_digit_avx(x_256, mask, sign);
_mm256_storeu_si256(y1.as_mut_ptr() as *mut __m256i, digit);
}
@@ -707,7 +815,7 @@ mod tests {
#[allow(dead_code)]
#[target_feature(enable = "avx2")]
fn test_get_carry_avx_internal() {
let basek: usize = 12;
let base2k: usize = 12;
let x: [i64; 4] = [
7638646372408325293,
-61440197422348985,
@@ -716,16 +824,16 @@ mod tests {
];
let carry: [i64; 4] = [1174467039, -144794816, -1466676977, 513122840];
let y0: Vec<i64> = vec![
get_carry(basek, x[0], carry[0]),
get_carry(basek, x[1], carry[1]),
get_carry(basek, x[2], carry[2]),
get_carry(basek, x[3], carry[3]),
get_carry_i64(base2k, x[0], carry[0]),
get_carry_i64(base2k, x[1], carry[1]),
get_carry_i64(base2k, x[2], carry[2]),
get_carry_i64(base2k, x[3], carry[3]),
];
let mut y1: Vec<i64> = vec![0i64; 4];
unsafe {
let x_256: __m256i = _mm256_loadu_si256(x.as_ptr() as *const __m256i);
let d_256: __m256i = _mm256_loadu_si256(carry.as_ptr() as *const __m256i);
let (_, _, basek_vec, top_mask) = normalize_consts_avx(basek);
let (_, _, basek_vec, top_mask) = normalize_consts_avx(base2k);
let digit: __m256i = get_carry_avx(x_256, d_256, basek_vec, top_mask);
_mm256_storeu_si256(y1.as_mut_ptr() as *mut __m256i, digit);
}
@@ -762,16 +870,16 @@ mod tests {
];
let mut c1: [i64; 4] = c0;
let basek = 12;
let base2k = 12;
znx_normalize_first_step_inplace_ref(basek, 0, &mut y0, &mut c0);
znx_normalize_first_step_inplace_avx(basek, 0, &mut y1, &mut c1);
znx_normalize_first_step_inplace_ref(base2k, 0, &mut y0, &mut c0);
znx_normalize_first_step_inplace_avx(base2k, 0, &mut y1, &mut c1);
assert_eq!(y0, y1);
assert_eq!(c0, c1);
znx_normalize_first_step_inplace_ref(basek, basek - 1, &mut y0, &mut c0);
znx_normalize_first_step_inplace_avx(basek, basek - 1, &mut y1, &mut c1);
znx_normalize_first_step_inplace_ref(base2k, base2k - 1, &mut y0, &mut c0);
znx_normalize_first_step_inplace_avx(base2k, base2k - 1, &mut y1, &mut c1);
assert_eq!(y0, y1);
assert_eq!(c0, c1);
@@ -807,16 +915,16 @@ mod tests {
];
let mut c1: [i64; 4] = c0;
let basek = 12;
let base2k = 12;
znx_normalize_middle_step_inplace_ref(basek, 0, &mut y0, &mut c0);
znx_normalize_middle_step_inplace_avx(basek, 0, &mut y1, &mut c1);
znx_normalize_middle_step_inplace_ref(base2k, 0, &mut y0, &mut c0);
znx_normalize_middle_step_inplace_avx(base2k, 0, &mut y1, &mut c1);
assert_eq!(y0, y1);
assert_eq!(c0, c1);
znx_normalize_middle_step_inplace_ref(basek, basek - 1, &mut y0, &mut c0);
znx_normalize_middle_step_inplace_avx(basek, basek - 1, &mut y1, &mut c1);
znx_normalize_middle_step_inplace_ref(base2k, base2k - 1, &mut y0, &mut c0);
znx_normalize_middle_step_inplace_avx(base2k, base2k - 1, &mut y1, &mut c1);
assert_eq!(y0, y1);
assert_eq!(c0, c1);
@@ -852,16 +960,16 @@ mod tests {
];
let mut c1: [i64; 4] = c0;
let basek = 12;
let base2k = 12;
znx_normalize_final_step_inplace_ref(basek, 0, &mut y0, &mut c0);
znx_normalize_final_step_inplace_avx(basek, 0, &mut y1, &mut c1);
znx_normalize_final_step_inplace_ref(base2k, 0, &mut y0, &mut c0);
znx_normalize_final_step_inplace_avx(base2k, 0, &mut y1, &mut c1);
assert_eq!(y0, y1);
assert_eq!(c0, c1);
znx_normalize_final_step_inplace_ref(basek, basek - 1, &mut y0, &mut c0);
znx_normalize_final_step_inplace_avx(basek, basek - 1, &mut y1, &mut c1);
znx_normalize_final_step_inplace_ref(base2k, base2k - 1, &mut y0, &mut c0);
znx_normalize_final_step_inplace_avx(base2k, base2k - 1, &mut y1, &mut c1);
assert_eq!(y0, y1);
assert_eq!(c0, c1);
@@ -898,16 +1006,16 @@ mod tests {
];
let mut c1: [i64; 4] = c0;
let basek = 12;
let base2k = 12;
znx_normalize_first_step_ref(basek, 0, &mut y0, &a, &mut c0);
znx_normalize_first_step_avx(basek, 0, &mut y1, &a, &mut c1);
znx_normalize_first_step_ref(base2k, 0, &mut y0, &a, &mut c0);
znx_normalize_first_step_avx(base2k, 0, &mut y1, &a, &mut c1);
assert_eq!(y0, y1);
assert_eq!(c0, c1);
znx_normalize_first_step_ref(basek, basek - 1, &mut y0, &a, &mut c0);
znx_normalize_first_step_avx(basek, basek - 1, &mut y1, &a, &mut c1);
znx_normalize_first_step_ref(base2k, base2k - 1, &mut y0, &a, &mut c0);
znx_normalize_first_step_avx(base2k, base2k - 1, &mut y1, &a, &mut c1);
assert_eq!(y0, y1);
assert_eq!(c0, c1);
@@ -944,16 +1052,16 @@ mod tests {
];
let mut c1: [i64; 4] = c0;
let basek = 12;
let base2k = 12;
znx_normalize_middle_step_ref(basek, 0, &mut y0, &a, &mut c0);
znx_normalize_middle_step_avx(basek, 0, &mut y1, &a, &mut c1);
znx_normalize_middle_step_ref(base2k, 0, &mut y0, &a, &mut c0);
znx_normalize_middle_step_avx(base2k, 0, &mut y1, &a, &mut c1);
assert_eq!(y0, y1);
assert_eq!(c0, c1);
znx_normalize_middle_step_ref(basek, basek - 1, &mut y0, &a, &mut c0);
znx_normalize_middle_step_avx(basek, basek - 1, &mut y1, &a, &mut c1);
znx_normalize_middle_step_ref(base2k, base2k - 1, &mut y0, &a, &mut c0);
znx_normalize_middle_step_avx(base2k, base2k - 1, &mut y1, &a, &mut c1);
assert_eq!(y0, y1);
assert_eq!(c0, c1);
@@ -990,16 +1098,16 @@ mod tests {
];
let mut c1: [i64; 4] = c0;
let basek = 12;
let base2k = 12;
znx_normalize_final_step_ref(basek, 0, &mut y0, &a, &mut c0);
znx_normalize_final_step_avx(basek, 0, &mut y1, &a, &mut c1);
znx_normalize_final_step_ref(base2k, 0, &mut y0, &a, &mut c0);
znx_normalize_final_step_avx(base2k, 0, &mut y1, &a, &mut c1);
assert_eq!(y0, y1);
assert_eq!(c0, c1);
znx_normalize_final_step_ref(basek, basek - 1, &mut y0, &a, &mut c0);
znx_normalize_final_step_avx(basek, basek - 1, &mut y1, &a, &mut c1);
znx_normalize_final_step_ref(base2k, base2k - 1, &mut y0, &a, &mut c0);
znx_normalize_final_step_avx(base2k, base2k - 1, &mut y1, &a, &mut c1);
assert_eq!(y0, y1);
assert_eq!(c0, c1);
@@ -1015,4 +1123,86 @@ mod tests {
test_znx_normalize_final_step_avx_internal();
}
}
#[target_feature(enable = "avx2")]
fn znx_extract_digit_addmul_internal() {
let mut y0: [i64; 4] = [
7638646372408325293,
-61440197422348985,
6835891051541717957,
-4835376105455195188,
];
let mut y1: [i64; 4] = y0;
let mut c0: [i64; 4] = [
621182201135793202,
9000856573317006236,
5542252755421113668,
-6036847263131690631,
];
let mut c1: [i64; 4] = c0;
let base2k: usize = 12;
znx_extract_digit_addmul_ref(base2k, 0, &mut y0, &mut c0);
znx_extract_digit_addmul_avx(base2k, 0, &mut y1, &mut c1);
assert_eq!(y0, y1);
assert_eq!(c0, c1);
znx_extract_digit_addmul_ref(base2k, base2k - 1, &mut y0, &mut c0);
znx_extract_digit_addmul_avx(base2k, base2k - 1, &mut y1, &mut c1);
assert_eq!(y0, y1);
assert_eq!(c0, c1);
}
#[test]
fn test_znx_extract_digit_addmul_avx() {
if !std::is_x86_feature_detected!("avx2") {
eprintln!("skipping: CPU lacks avx2");
return;
};
unsafe {
znx_extract_digit_addmul_internal();
}
}
#[target_feature(enable = "avx2")]
fn znx_normalize_digit_internal() {
let mut y0: [i64; 4] = [
7638646372408325293,
-61440197422348985,
6835891051541717957,
-4835376105455195188,
];
let mut y1: [i64; 4] = y0;
let mut c0: [i64; 4] = [
621182201135793202,
9000856573317006236,
5542252755421113668,
-6036847263131690631,
];
let mut c1: [i64; 4] = c0;
let base2k: usize = 12;
znx_normalize_digit_ref(base2k, &mut y0, &mut c0);
znx_normalize_digit_avx(base2k, &mut y1, &mut c1);
assert_eq!(y0, y1);
assert_eq!(c0, c1);
}
#[test]
fn test_znx_normalize_digit_internal_avx() {
if !std::is_x86_feature_detected!("avx2") {
eprintln!("skipping: CPU lacks avx2");
return;
};
unsafe {
znx_normalize_digit_internal();
}
}
}
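Taken together, the first/middle/final step functions above perform per-coefficient carry propagation across the limbs of a base-2^k decomposition. The sketch below shows one plausible way they compose, under explicit assumptions: limbs are visited from least significant (last) to most significant (first), lsh is 0, and the carry left over after the top limb is dropped. The real code splits each value into digit and carry before adding the incoming carry, which is mathematically equivalent but keeps intermediates in range.

// Conceptual scalar sketch only; limb ordering and truncation are assumptions.
fn normalize_scalar_sketch(base2k: usize, limbs: &mut [Vec<i64>]) {
    let size = limbs.len();
    if size == 0 {
        return;
    }
    let n = limbs[0].len();
    debug_assert!(limbs.iter().all(|l| l.len() == n));
    let mut carry = vec![0i64; n];
    for (step, limb) in limbs.iter_mut().rev().enumerate() {
        for i in 0..n {
            // first step: no incoming carry; middle/final steps: add it.
            let x = if step == 0 { limb[i] } else { limb[i].wrapping_add(carry[i]) };
            let digit = (x << (64 - base2k)) >> (64 - base2k);
            if step + 1 < size {
                carry[i] = x.wrapping_sub(digit) >> base2k; // propagate upward
            } // final step: the leftover carry is dropped
            limb[i] = digit;
        }
    }
}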


@@ -41,7 +41,7 @@ pub fn znx_sub_avx(res: &mut [i64], a: &[i64], b: &[i64]) {
/// Caller must ensure the CPU supports AVX2 (e.g., via `is_x86_feature_detected!("avx2")`);
/// all inputs must have the same length and must not alias.
#[target_feature(enable = "avx2")]
pub fn znx_sub_ab_inplace_avx(res: &mut [i64], a: &[i64]) {
pub fn znx_sub_inplace_avx(res: &mut [i64], a: &[i64]) {
#[cfg(debug_assertions)]
{
assert_eq!(res.len(), a.len());
@@ -67,9 +67,9 @@ pub fn znx_sub_ab_inplace_avx(res: &mut [i64], a: &[i64]) {
// tail
if !res.len().is_multiple_of(4) {
use poulpy_hal::reference::znx::znx_sub_ab_inplace_ref;
use poulpy_hal::reference::znx::znx_sub_inplace_ref;
znx_sub_ab_inplace_ref(&mut res[span << 2..], &a[span << 2..]);
znx_sub_inplace_ref(&mut res[span << 2..], &a[span << 2..]);
}
}
@@ -77,7 +77,7 @@ pub fn znx_sub_ab_inplace_avx(res: &mut [i64], a: &[i64]) {
/// Caller must ensure the CPU supports AVX2 (e.g., via `is_x86_feature_detected!("avx2")`);
/// all inputs must have the same length and must not alias.
#[target_feature(enable = "avx2")]
pub fn znx_sub_ba_inplace_avx(res: &mut [i64], a: &[i64]) {
pub fn znx_sub_negate_inplace_avx(res: &mut [i64], a: &[i64]) {
#[cfg(debug_assertions)]
{
assert_eq!(res.len(), a.len());
@@ -103,8 +103,8 @@ pub fn znx_sub_ba_inplace_avx(res: &mut [i64], a: &[i64]) {
// tail
if !res.len().is_multiple_of(4) {
use poulpy_hal::reference::znx::znx_sub_ba_inplace_ref;
use poulpy_hal::reference::znx::znx_sub_negate_inplace_ref;
znx_sub_ba_inplace_ref(&mut res[span << 2..], &a[span << 2..]);
znx_sub_negate_inplace_ref(&mut res[span << 2..], &a[span << 2..]);
}
}