Ref. + AVX code & generic tests + benches (#85)

This commit is contained in:
Jean-Philippe Bossuat
2025-09-15 16:16:11 +02:00
committed by GitHub
parent 99b9e3e10e
commit 56dbd29c59
286 changed files with 27797 additions and 7270 deletions

View File

@@ -0,0 +1,76 @@
/// # Safety
/// Caller must ensure the CPU supports AVX2 (e.g., via `is_x86_feature_detected!("avx2")`);
/// all inputs must have the same length and must not alias.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
pub fn znx_add_avx(res: &mut [i64], a: &[i64], b: &[i64]) {
#[cfg(debug_assertions)]
{
assert_eq!(res.len(), a.len());
assert_eq!(res.len(), b.len());
}
use core::arch::x86_64::{__m256i, _mm256_add_epi64, _mm256_loadu_si256, _mm256_storeu_si256};
let n: usize = res.len();
let span: usize = n >> 2;
let mut rr: *mut __m256i = res.as_mut_ptr() as *mut __m256i;
let mut aa: *const __m256i = a.as_ptr() as *const __m256i;
let mut bb: *const __m256i = b.as_ptr() as *const __m256i;
unsafe {
for _ in 0..span {
let sum: __m256i = _mm256_add_epi64(_mm256_loadu_si256(aa), _mm256_loadu_si256(bb));
_mm256_storeu_si256(rr, sum);
rr = rr.add(1);
aa = aa.add(1);
bb = bb.add(1);
}
}
// tail
if !res.len().is_multiple_of(4) {
use poulpy_hal::reference::znx::znx_add_ref;
znx_add_ref(&mut res[span << 2..], &a[span << 2..], &b[span << 2..]);
}
}
/// # Safety
/// Caller must ensure the CPU supports AVX2 (e.g., via `is_x86_feature_detected!("avx2")`);
/// all inputs must have the same length and must not alias.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
pub fn znx_add_inplace_avx(res: &mut [i64], a: &[i64]) {
#[cfg(debug_assertions)]
{
assert_eq!(res.len(), a.len());
}
use core::arch::x86_64::{__m256i, _mm256_add_epi64, _mm256_loadu_si256, _mm256_storeu_si256};
let n: usize = res.len();
let span: usize = n >> 2;
let mut rr: *mut __m256i = res.as_mut_ptr() as *mut __m256i;
let mut aa: *const __m256i = a.as_ptr() as *const __m256i;
unsafe {
for _ in 0..span {
let sum: __m256i = _mm256_add_epi64(_mm256_loadu_si256(rr), _mm256_loadu_si256(aa));
_mm256_storeu_si256(rr, sum);
rr = rr.add(1);
aa = aa.add(1);
}
}
// tail
if !res.len().is_multiple_of(4) {
use poulpy_hal::reference::znx::znx_add_inplace_ref;
znx_add_inplace_ref(&mut res[span << 2..], &a[span << 2..]);
}
}
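// A hedged usage sketch (not part of this commit): callers that are not
// themselves compiled with AVX2 enabled can dispatch at runtime and fall back
// to the portable reference kernel. The wrapper name `znx_add_dispatch` is
// illustrative; only `znx_add_avx` (above) and
// `poulpy_hal::reference::znx::znx_add_ref` come from the source. It assumes a
// toolchain where safe `#[target_feature]` functions may be called from an
// `unsafe` block, which the kernels above already rely on.
#[cfg(target_arch = "x86_64")]
#[allow(dead_code)]
fn znx_add_dispatch(res: &mut [i64], a: &[i64], b: &[i64]) {
    if std::is_x86_feature_detected!("avx2") {
        // SAFETY: the runtime check above guarantees AVX2 is available.
        unsafe { znx_add_avx(res, a, b) }
    } else {
        poulpy_hal::reference::znx::znx_add_ref(res, a, b);
    }
}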

View File

@@ -0,0 +1,133 @@
use core::arch::x86_64::*;
#[inline]
fn inv_mod_pow2(p: usize, bits: u32) -> usize {
// Compute p^{-1} mod 2^bits (p must be odd) through Hensel lifting.
debug_assert!(p % 2 == 1);
let mut x: usize = 1usize; // inverse mod 2
let mut i: u32 = 1;
while i < bits {
// Newton/Hensel step: x <- x * (2 - p*x) (wrapping arithmetic); each step doubles the number of correct low bits, so x becomes a valid inverse mod 2^(2i).
x = x.wrapping_mul(2usize.wrapping_sub(p.wrapping_mul(x)));
i <<= 1;
}
x & ((1usize << bits) - 1)
}
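// A small, hedged check of the lifting above (illustrative, not part of this
// commit): for any odd p, inv_mod_pow2(p, bits) * p == 1 (mod 2^bits).
#[test]
fn test_inv_mod_pow2_is_inverse() {
    let bits: u32 = 11; // e.g. 2n = 2048
    let mask: usize = (1usize << bits) - 1;
    for p in (1usize..4096).step_by(2) {
        let inv = inv_mod_pow2(p, bits);
        assert_eq!(p.wrapping_mul(inv) & mask, 1, "p = {p}");
    }
}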
/// # Safety
/// Caller must ensure the CPU supports AVX2 (e.g., via `is_x86_feature_detected!("avx2")`);
/// all inputs must have the same length and must not alias.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2", enable = "fma")]
pub fn znx_automorphism_avx(p: i64, res: &mut [i64], a: &[i64]) {
debug_assert_eq!(res.len(), a.len());
let n: usize = res.len();
if n == 0 {
return;
}
debug_assert!(n.is_power_of_two(), "n must be power of two");
debug_assert!(p & 1 == 1, "p must be odd (invertible mod 2n)");
if n < 4 {
use poulpy_hal::reference::znx::znx_automorphism_ref;
znx_automorphism_ref(p, res, a);
return;
}
unsafe {
let two_n: usize = n << 1;
let span: usize = n >> 2;
let bits: u32 = (two_n as u64).trailing_zeros();
let mask_2n: usize = two_n - 1;
let mask_1n: usize = n - 1;
// p mod 2n (positive)
let p_2n: usize = (((p & mask_2n as i64) + two_n as i64) as usize) & mask_2n;
// p^-1 mod 2n
let inv: usize = inv_mod_pow2(p_2n, bits);
// Broadcast constants
let n_minus1_vec: __m256i = _mm256_set1_epi64x((n as i64) - 1);
let mask_2n_vec: __m256i = _mm256_set1_epi64x(mask_2n as i64);
let mask_1n_vec: __m256i = _mm256_set1_epi64x(mask_1n as i64);
// Lane offsets [0, inv, 2*inv, 3*inv] (mod 2n)
let lane_offsets: __m256i = _mm256_set_epi64x(
((inv * 3) & mask_2n) as i64,
((inv * 2) & mask_2n) as i64,
inv as i64,
0i64,
);
// t_base = (j * inv) mod 2n.
let mut t_base: usize = 0;
let step: usize = (inv << 2) & mask_2n;
let mut rr: *mut __m256i = res.as_mut_ptr() as *mut __m256i;
let aa: *const i64 = a.as_ptr();
for _ in 0..span {
// t_vec = (t_base + [0, inv, 2*inv, 3*inv]) & (2n-1)
let t_base_vec: __m256i = _mm256_set1_epi64x(t_base as i64);
let t_vec: __m256i = _mm256_and_si256(_mm256_add_epi64(t_base_vec, lane_offsets), mask_2n_vec);
// idx = t_vec & (n-1)
let idx_vec: __m256i = _mm256_and_si256(t_vec, mask_1n_vec);
// sign = t >= n ? -1 : 0 (mask of all-ones where negate)
let sign_mask: __m256i = _mm256_cmpgt_epi64(t_vec, n_minus1_vec);
// gather a[idx] (scale = 8 bytes per i64)
let vals: __m256i = _mm256_i64gather_epi64(aa, idx_vec, 8);
// Conditional negate: (vals ^ sign_mask) - sign_mask
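// (sign_mask == -1: (v ^ -1) - (-1) = !v + 1 = -v; sign_mask == 0: v unchanged)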
let vals_x: __m256i = _mm256_xor_si256(vals, sign_mask);
let out: __m256i = _mm256_sub_epi64(vals_x, sign_mask);
// store to res[j..j+4]
_mm256_storeu_si256(rr, out);
// advance
rr = rr.add(1);
t_base = (t_base + step) & mask_2n;
}
}
}
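// A scalar restatement of the per-lane indexing above (illustrative, not part
// of this commit): for output index j, t = j * p^{-1} mod 2n selects the input
// coefficient, negated when t falls in [n, 2n) because X^t = -X^{t-n} in
// Z[X]/(X^n + 1). `znx_automorphism_scalar_model` is a hypothetical helper,
// not an API of this crate.
#[allow(dead_code)]
fn znx_automorphism_scalar_model(p: i64, res: &mut [i64], a: &[i64]) {
    let n = res.len();
    if n == 0 {
        return;
    }
    debug_assert!(res.len() == a.len() && n.is_power_of_two() && p & 1 == 1);
    let two_n = n << 1;
    let p_2n = (((p & (two_n as i64 - 1)) + two_n as i64) as usize) & (two_n - 1);
    let inv = inv_mod_pow2(p_2n, (two_n as u64).trailing_zeros());
    for (j, r) in res.iter_mut().enumerate() {
        let t = (j * inv) & (two_n - 1);
        let v = a[t & (n - 1)];
        *r = if t >= n { -v } else { v };
    }
}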
#[cfg(all(test, target_arch = "x86_64"))]
mod tests {
use poulpy_hal::reference::znx::znx_automorphism_ref;
use super::*;
#[target_feature(enable = "avx2", enable = "fma")]
fn test_znx_automorphism_internal() {
let a: [i64; 16] = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16];
let p: i64 = -5;
let mut r0: Vec<i64> = vec![0i64; a.len()];
let mut r1: Vec<i64> = vec![0i64; a.len()];
znx_automorphism_ref(p, &mut r0, &a);
znx_automorphism_avx(p, &mut r1, &a);
assert_eq!(r0, r1);
}
#[test]
fn test_znx_automorphism_avx() {
if !std::is_x86_feature_detected!("avx2") {
eprintln!("skipping: CPU lacks avx2");
return;
}
unsafe {
test_znx_automorphism_internal();
}
}
}

View File

@@ -0,0 +1,13 @@
mod add;
mod automorphism;
mod neg;
mod normalization;
mod sub;
mod switch_ring;
pub(crate) use add::*;
pub(crate) use automorphism::*;
pub(crate) use neg::*;
pub(crate) use normalization::*;
pub(crate) use sub::*;
pub(crate) use switch_ring::*;

View File

@@ -0,0 +1,64 @@
/// # Safety
/// Caller must ensure the CPU supports AVX2 (e.g., via `is_x86_feature_detected!("avx2")`);
/// all inputs must have the same length and must not alias.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
pub fn znx_negate_avx(res: &mut [i64], src: &[i64]) {
#[cfg(debug_assertions)]
{
assert_eq!(res.len(), src.len())
}
let n: usize = res.len();
use std::arch::x86_64::{__m256i, _mm256_loadu_si256, _mm256_setzero_si256, _mm256_storeu_si256, _mm256_sub_epi64};
let span: usize = n >> 2;
unsafe {
let mut aa: *const __m256i = src.as_ptr() as *const __m256i;
let mut rr: *mut __m256i = res.as_mut_ptr() as *mut __m256i;
let zero: __m256i = _mm256_setzero_si256();
for _ in 0..span {
let v: __m256i = _mm256_loadu_si256(aa);
let neg: __m256i = _mm256_sub_epi64(zero, v);
_mm256_storeu_si256(rr, neg);
aa = aa.add(1);
rr = rr.add(1);
}
}
if !res.len().is_multiple_of(4) {
use poulpy_hal::reference::znx::znx_negate_ref;
znx_negate_ref(&mut res[span << 2..], &src[span << 2..])
}
}
/// # Safety
/// Caller must ensure the CPU supports AVX2 (e.g., via `is_x86_feature_detected!("avx2")`);
/// all inputs must have the same length and must not alias.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
pub fn znx_negate_inplace_avx(res: &mut [i64]) {
let n: usize = res.len();
use std::arch::x86_64::{__m256i, _mm256_loadu_si256, _mm256_setzero_si256, _mm256_storeu_si256, _mm256_sub_epi64};
let span: usize = n >> 2;
unsafe {
let mut rr: *mut __m256i = res.as_mut_ptr() as *mut __m256i;
let zero: __m256i = _mm256_setzero_si256();
for _ in 0..span {
let v: __m256i = _mm256_loadu_si256(rr);
let neg: __m256i = _mm256_sub_epi64(zero, v);
_mm256_storeu_si256(rr, neg);
rr = rr.add(1);
}
}
if !res.len().is_multiple_of(4) {
use poulpy_hal::reference::znx::znx_negate_inplace_ref;
znx_negate_inplace_ref(&mut res[span << 2..])
}
}

File diff suppressed because it is too large

View File

@@ -0,0 +1,113 @@
/// # Safety
/// Caller must ensure the CPU supports AVX2 (e.g., via `is_x86_feature_detected!("avx2")`);
/// all inputs must have the same length and must not alias.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
pub fn znx_sub_avx(res: &mut [i64], a: &[i64], b: &[i64]) {
#[cfg(debug_assertions)]
{
assert_eq!(res.len(), a.len());
assert_eq!(res.len(), b.len());
}
use core::arch::x86_64::{__m256i, _mm256_loadu_si256, _mm256_storeu_si256, _mm256_sub_epi64};
let n: usize = res.len();
let span: usize = n >> 2;
let mut rr: *mut __m256i = res.as_mut_ptr() as *mut __m256i;
let mut aa: *const __m256i = a.as_ptr() as *const __m256i;
let mut bb: *const __m256i = b.as_ptr() as *const __m256i;
unsafe {
for _ in 0..span {
let diff: __m256i = _mm256_sub_epi64(_mm256_loadu_si256(aa), _mm256_loadu_si256(bb));
_mm256_storeu_si256(rr, diff);
rr = rr.add(1);
aa = aa.add(1);
bb = bb.add(1);
}
}
// tail
if !res.len().is_multiple_of(4) {
use poulpy_hal::reference::znx::znx_sub_ref;
znx_sub_ref(&mut res[span << 2..], &a[span << 2..], &b[span << 2..]);
}
}
/// # Safety
/// Caller must ensure the CPU supports AVX2 (e.g., via `is_x86_feature_detected!("avx2")`);
/// all inputs must have the same length and must not alias.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
pub fn znx_sub_ab_inplace_avx(res: &mut [i64], a: &[i64]) {
#[cfg(debug_assertions)]
{
assert_eq!(res.len(), a.len());
}
use core::arch::x86_64::{__m256i, _mm256_loadu_si256, _mm256_storeu_si256, _mm256_sub_epi64};
let n: usize = res.len();
let span: usize = n >> 2;
let mut rr: *mut __m256i = res.as_mut_ptr() as *mut __m256i;
let mut aa: *const __m256i = a.as_ptr() as *const __m256i;
unsafe {
for _ in 0..span {
let diff: __m256i = _mm256_sub_epi64(_mm256_loadu_si256(rr), _mm256_loadu_si256(aa));
_mm256_storeu_si256(rr, diff);
rr = rr.add(1);
aa = aa.add(1);
}
}
// tail
if !res.len().is_multiple_of(4) {
use poulpy_hal::reference::znx::znx_sub_ab_inplace_ref;
znx_sub_ab_inplace_ref(&mut res[span << 2..], &a[span << 2..]);
}
}
/// # Safety
/// Caller must ensure the CPU supports AVX2 (e.g., via `is_x86_feature_detected!("avx2")`);
/// all inputs must have the same length and must not alias.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
pub fn znx_sub_ba_inplace_avx(res: &mut [i64], a: &[i64]) {
#[cfg(debug_assertions)]
{
assert_eq!(res.len(), a.len());
}
use core::arch::x86_64::{__m256i, _mm256_loadu_si256, _mm256_storeu_si256, _mm256_sub_epi64};
let n: usize = res.len();
let span: usize = n >> 2;
let mut rr: *mut __m256i = res.as_mut_ptr() as *mut __m256i;
let mut aa: *const __m256i = a.as_ptr() as *const __m256i;
unsafe {
for _ in 0..span {
let diff: __m256i = _mm256_sub_epi64(_mm256_loadu_si256(aa), _mm256_loadu_si256(rr));
_mm256_storeu_si256(rr, diff);
rr = rr.add(1);
aa = aa.add(1);
}
}
// tail
if !res.len().is_multiple_of(4) {
use poulpy_hal::reference::znx::znx_sub_ba_inplace_ref;
znx_sub_ba_inplace_ref(&mut res[span << 2..], &a[span << 2..]);
}
}
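// A hedged test sketch (not part of this commit) pinning down the operand
// order of the two in-place variants: `znx_sub_ab_inplace_avx` computes
// res <- res - a, while `znx_sub_ba_inplace_avx` computes res <- a - res.
#[cfg(target_arch = "x86_64")]
#[test]
fn test_znx_sub_inplace_operand_order_sketch() {
    if !std::is_x86_feature_detected!("avx2") {
        eprintln!("skipping: CPU lacks avx2");
        return;
    }
    let a: [i64; 8] = [1, 2, 3, 4, 5, 6, 7, 8];
    let mut r_ab: [i64; 8] = [10; 8];
    let mut r_ba: [i64; 8] = [10; 8];
    unsafe {
        znx_sub_ab_inplace_avx(&mut r_ab, &a); // r_ab[i] = 10 - a[i]
        znx_sub_ba_inplace_avx(&mut r_ba, &a); // r_ba[i] = a[i] - 10
    }
    assert_eq!(r_ab, [9, 8, 7, 6, 5, 4, 3, 2]);
    assert_eq!(r_ba, [-9, -8, -7, -6, -5, -4, -3, -2]);
}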

View File

@@ -0,0 +1,87 @@
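/// # Safety
/// Caller must ensure the CPU supports AVX2 (e.g., via `is_x86_feature_detected!("avx2")`);
/// `res.len()` and `a.len()` must be powers of two with one dividing the other, each at least 4
/// (unless equal, in which case the call degenerates to a copy), and the slices must not alias.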
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
pub unsafe fn znx_switch_ring_avx(res: &mut [i64], a: &[i64]) {
unsafe {
use core::arch::x86_64::*;
let (n_in, n_out) = (a.len(), res.len());
#[cfg(debug_assertions)]
{
assert!(n_in.is_power_of_two());
assert!(n_in.max(n_out).is_multiple_of(n_in.min(n_out)));
// The vectorized paths below assume at least one full 4-lane group on each side.
assert!(n_in == n_out || n_in.min(n_out) >= 4);
}
if n_in == n_out {
use poulpy_hal::reference::znx::znx_copy_ref;
znx_copy_ref(res, a);
return;
}
if n_in > n_out {
// Downsample: res[k] = a[k * gap_in], contiguous stores
let gap_in: usize = n_in / n_out;
// index vector: [0, 1, 2, 3] * gap_in
let step: __m256i = _mm256_setr_epi64x(0, gap_in as i64, 2 * gap_in as i64, 3 * gap_in as i64);
let span: usize = n_out >> 2;
let bump: __m256i = _mm256_set1_epi64x(4 * gap_in as i64);
let mut res_4xi64: *mut __m256i = res.as_mut_ptr() as *mut __m256i;
let a_ptr: *const i64 = a.as_ptr();
let mut base: __m256i = _mm256_setzero_si256(); // starts at 0*gap
for _ in 0..span {
// idx = base + step
let idx: __m256i = _mm256_add_epi64(base, step);
// gather 4 spaced i64 (scale=8 bytes)
let v: __m256i = _mm256_i64gather_epi64(a_ptr, idx, 8);
// store contiguously
_mm256_storeu_si256(res_4xi64, v);
base = _mm256_add_epi64(base, bump);
res_4xi64 = res_4xi64.add(1);
}
} else {
// Upsample: res[k * gap_out] = a[k], i.e. res has holes;
use poulpy_hal::reference::znx::znx_zero_ref;
let gap_out = n_out / n_in;
// zero then scatter scalar stores
znx_zero_ref(res);
let mut a_4xi64: *const __m256i = a.as_ptr() as *const __m256i;
for i in (0..n_in).step_by(4) {
// Load contiguously 4 inputs
let v = _mm256_loadu_si256(a_4xi64);
// extract the 4 lanes (vextracti128 + pextrq); AVX2 has no 64-bit scatter, so scalar strided stores are the best we can do.
let x0: i64 = _mm256_extract_epi64(v, 0);
let x1: i64 = _mm256_extract_epi64(v, 1);
let x2: i64 = _mm256_extract_epi64(v, 2);
let x3: i64 = _mm256_extract_epi64(v, 3);
// starting output pointer for this group
let mut p: *mut i64 = res.as_mut_ptr().add(i * gap_out);
// four strided stores with pointer bump (avoid mul each time)
*p = x0;
p = p.add(gap_out);
*p = x1;
p = p.add(gap_out);
*p = x2;
p = p.add(gap_out);
*p = x3;
a_4xi64 = a_4xi64.add(1)
}
}
}
}
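// A hedged round-trip sketch (not part of this commit): upsampling writes a[k]
// to res[k * gap] and downsampling reads res[k * gap] back, so switching up and
// then back down by the same factor should reproduce the input (lengths chosen
// as multiples of 4, as the vectorized paths above require).
#[cfg(target_arch = "x86_64")]
#[test]
fn test_znx_switch_ring_round_trip_sketch() {
    if !std::is_x86_feature_detected!("avx2") {
        eprintln!("skipping: CPU lacks avx2");
        return;
    }
    let a: Vec<i64> = (1i64..=8).collect();
    let mut up = vec![0i64; 32];
    let mut back = vec![0i64; 8];
    unsafe {
        znx_switch_ring_avx(&mut up, &a); // upsample 8 -> 32
        znx_switch_ring_avx(&mut back, &up); // downsample 32 -> 8
    }
    assert_eq!(a, back);
}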