Ref. + AVX code & generic tests + benches (#85)

This commit is contained in:
Jean-Philippe Bossuat
2025-09-15 16:16:11 +02:00
committed by GitHub
parent 99b9e3e10e
commit 56dbd29c59
286 changed files with 27797 additions and 7270 deletions

View File

@@ -0,0 +1,271 @@
/// # Correctness
/// Only ensured for inputs with absolute value bounded by 2^50-1
/// # Safety
/// Caller must ensure the CPU supports AVX2 and FMA (e.g., via `is_x86_feature_detected!("avx2") && is_x86_feature_detected!("fma")`).
#[cfg(any(target_arch = "x86_64", target_arch = "x86"))]
#[target_feature(enable = "avx2,fma")]
pub fn reim_from_znx_i64_bnd50_fma(res: &mut [f64], a: &[i64]) {
#[cfg(debug_assertions)]
{
assert_eq!(res.len(), a.len())
}
let n: usize = res.len();
unsafe {
use std::arch::x86_64::{
__m256d, __m256i, _mm256_add_epi64, _mm256_castsi256_pd, _mm256_loadu_si256, _mm256_or_pd, _mm256_set1_epi64x,
_mm256_set1_pd, _mm256_storeu_pd, _mm256_sub_pd,
};
let expo: f64 = (1i64 << 52) as f64;
let add_cst: i64 = 1i64 << 51;
let sub_cst: f64 = (3i64 << 51) as f64;
let expo_256: __m256d = _mm256_set1_pd(expo);
let add_cst_256: __m256i = _mm256_set1_epi64x(add_cst);
let sub_cst_256: __m256d = _mm256_set1_pd(sub_cst);
let mut res_ptr: *mut f64 = res.as_mut_ptr();
let mut a_ptr: *const __m256i = a.as_ptr() as *const __m256i;
let span: usize = n >> 2;
for _ in 0..span {
let mut ai64_256: __m256i = _mm256_loadu_si256(a_ptr);
ai64_256 = _mm256_add_epi64(ai64_256, add_cst_256);
let mut af64_256: __m256d = _mm256_castsi256_pd(ai64_256);
af64_256 = _mm256_or_pd(af64_256, expo_256);
af64_256 = _mm256_sub_pd(af64_256, sub_cst_256);
_mm256_storeu_pd(res_ptr, af64_256);
res_ptr = res_ptr.add(4);
a_ptr = a_ptr.add(1);
}
if !res.len().is_multiple_of(4) {
use poulpy_hal::reference::fft64::reim::reim_from_znx_i64_ref;
reim_from_znx_i64_ref(&mut res[span << 2..], &a[span << 2..])
}
}
}
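// Illustrative scalar sketch (not from the original commit) of the branchless
// i64 -> f64 conversion used above: for |a| < 2^51, adding 2^51, OR-ing in the
// exponent bits of 2^52 and subtracting 1.5 * 2^52 reproduces `a` exactly.
#[test]
fn znx_i64_to_f64_bit_trick_scalar_sketch() {
for &a in &[-1234567i64, -1, 0, 1, (1i64 << 50) - 1, -(1i64 << 50) + 1] {
let bits: u64 = ((a + (1i64 << 51)) as u64) | ((1i64 << 52) as f64).to_bits();
let out: f64 = f64::from_bits(bits) - (3i64 << 51) as f64;
assert_eq!(out, a as f64);
}
}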
/// # Correctness
/// Only ensured for inputs with absolute value bounded by 2^63-1
/// # Safety
/// Caller must ensure the CPU supports AVX2 and FMA (e.g., via `is_x86_feature_detected!("avx2") && is_x86_feature_detected!("fma")`).
#[allow(dead_code)]
#[cfg(any(target_arch = "x86_64", target_arch = "x86"))]
#[target_feature(enable = "avx2,fma")]
pub fn reim_to_znx_i64_bnd63_avx2_fma(res: &mut [i64], divisor: f64, a: &[f64]) {
#[cfg(debug_assertions)]
{
assert_eq!(res.len(), a.len())
}
let sign_mask: u64 = 0x8000000000000000u64;
let expo_mask: u64 = 0x7FF0000000000000u64;
let mantissa_mask: u64 = (i64::MAX as u64) ^ expo_mask;
let mantissa_msb: u64 = 0x0010000000000000u64;
let divi_bits: f64 = divisor * (1i64 << 52) as f64;
let offset: f64 = divisor / 2.;
unsafe {
use std::arch::x86_64::{
__m256d, __m256i, _mm256_add_pd, _mm256_and_pd, _mm256_and_si256, _mm256_castpd_si256, _mm256_castsi256_pd,
_mm256_loadu_pd, _mm256_or_pd, _mm256_or_si256, _mm256_set1_epi64x, _mm256_set1_pd, _mm256_sllv_epi64,
_mm256_srli_epi64, _mm256_srlv_epi64, _mm256_storeu_si256, _mm256_sub_epi64, _mm256_xor_si256,
};
let sign_mask_256: __m256d = _mm256_castsi256_pd(_mm256_set1_epi64x(sign_mask as i64));
let expo_mask_256: __m256i = _mm256_set1_epi64x(expo_mask as i64);
let mantissa_mask_256: __m256i = _mm256_set1_epi64x(mantissa_mask as i64);
let mantissa_msb_256: __m256i = _mm256_set1_epi64x(mantissa_msb as i64);
let offset_256 = _mm256_set1_pd(offset);
let divi_bits_256 = _mm256_castpd_si256(_mm256_set1_pd(divi_bits));
let mut res_ptr: *mut __m256i = res.as_mut_ptr() as *mut __m256i;
let mut a_ptr: *const f64 = a.as_ptr();
let span: usize = res.len() >> 2;
for _ in 0..span {
// read the next value
let mut a: __m256d = _mm256_loadu_pd(a_ptr);
// a += sign(a) * m/2
let asign: __m256d = _mm256_and_pd(a, sign_mask_256);
a = _mm256_add_pd(a, _mm256_or_pd(asign, offset_256));
// sign: either 0 or -1
let mut sign_mask: __m256i = _mm256_castpd_si256(asign);
sign_mask = _mm256_sub_epi64(_mm256_set1_epi64x(0), _mm256_srli_epi64(sign_mask, 63));
// compute the exponents
let a0exp: __m256i = _mm256_and_si256(_mm256_castpd_si256(a), expo_mask_256);
let mut a0lsh: __m256i = _mm256_sub_epi64(a0exp, divi_bits_256);
let mut a0rsh: __m256i = _mm256_sub_epi64(divi_bits_256, a0exp);
a0lsh = _mm256_srli_epi64(a0lsh, 52);
a0rsh = _mm256_srli_epi64(a0rsh, 52);
// compute the new mantissa
let mut a0pos: __m256i = _mm256_and_si256(_mm256_castpd_si256(a), mantissa_mask_256);
a0pos = _mm256_or_si256(a0pos, mantissa_msb_256);
a0lsh = _mm256_sllv_epi64(a0pos, a0lsh);
a0rsh = _mm256_srlv_epi64(a0pos, a0rsh);
let mut out: __m256i = _mm256_or_si256(a0lsh, a0rsh);
// negate if the sign was negative
out = _mm256_xor_si256(out, sign_mask);
out = _mm256_sub_epi64(out, sign_mask);
// stores
_mm256_storeu_si256(res_ptr, out);
res_ptr = res_ptr.add(1);
a_ptr = a_ptr.add(4);
}
if !res.len().is_multiple_of(4) {
use poulpy_hal::reference::fft64::reim::reim_to_znx_i64_ref;
reim_to_znx_i64_ref(&mut res[span << 2..], divisor, &a[span << 2..])
}
}
}
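// Illustrative caller sketch (not part of this commit): `#[target_feature]`
// functions such as the one above are only sound to call once the required CPU
// features have been confirmed at runtime, e.g.:
//
// if is_x86_feature_detected!("avx2") && is_x86_feature_detected!("fma") {
// unsafe { reim_to_znx_i64_bnd63_avx2_fma(&mut res, divisor, &a) };
// } else {
// reim_to_znx_i64_ref(&mut res, divisor, &a);
// }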
/// # Correctness
/// Only ensured for inputs with absolute value bounded by 2^63-1
/// # Safety
/// Caller must ensure the CPU supports AVX2 and FMA (e.g., via `is_x86_feature_detected!("avx2") && is_x86_feature_detected!("fma")`).
#[cfg(any(target_arch = "x86_64", target_arch = "x86"))]
#[target_feature(enable = "avx2,fma")]
pub fn reim_to_znx_i64_inplace_bnd63_avx2_fma(res: &mut [f64], divisor: f64) {
let sign_mask: u64 = 0x8000000000000000u64;
let expo_mask: u64 = 0x7FF0000000000000u64;
let mantissa_mask: u64 = (i64::MAX as u64) ^ expo_mask;
let mantissa_msb: u64 = 0x0010000000000000u64;
let divi_bits: f64 = divisor * (1i64 << 52) as f64;
let offset: f64 = divisor / 2.;
unsafe {
use std::arch::x86_64::{
__m256d, __m256i, _mm256_add_pd, _mm256_and_pd, _mm256_and_si256, _mm256_castpd_si256, _mm256_castsi256_pd,
_mm256_loadu_pd, _mm256_or_pd, _mm256_or_si256, _mm256_set1_epi64x, _mm256_set1_pd, _mm256_sllv_epi64,
_mm256_srli_epi64, _mm256_srlv_epi64, _mm256_storeu_si256, _mm256_sub_epi64, _mm256_xor_si256,
};
use poulpy_hal::reference::fft64::reim::reim_to_znx_i64_inplace_ref;
let sign_mask_256: __m256d = _mm256_castsi256_pd(_mm256_set1_epi64x(sign_mask as i64));
let expo_mask_256: __m256i = _mm256_set1_epi64x(expo_mask as i64);
let mantissa_mask_256: __m256i = _mm256_set1_epi64x(mantissa_mask as i64);
let mantissa_msb_256: __m256i = _mm256_set1_epi64x(mantissa_msb as i64);
let offset_256: __m256d = _mm256_set1_pd(offset);
let divi_bits_256: __m256i = _mm256_castpd_si256(_mm256_set1_pd(divi_bits));
let mut res_ptr_4xi64: *mut __m256i = res.as_mut_ptr() as *mut __m256i;
let mut res_ptr_1xf64: *mut f64 = res.as_mut_ptr();
let span: usize = res.len() >> 2;
for _ in 0..span {
// read the next value
let mut a: __m256d = _mm256_loadu_pd(res_ptr_1xf64);
// a += sign(a) * m/2
let asign: __m256d = _mm256_and_pd(a, sign_mask_256);
a = _mm256_add_pd(a, _mm256_or_pd(asign, offset_256));
// sign: either 0 or -1
let mut sign_mask: __m256i = _mm256_castpd_si256(asign);
sign_mask = _mm256_sub_epi64(_mm256_set1_epi64x(0), _mm256_srli_epi64(sign_mask, 63));
// compute the exponents
let a0exp: __m256i = _mm256_and_si256(_mm256_castpd_si256(a), expo_mask_256);
let mut a0lsh: __m256i = _mm256_sub_epi64(a0exp, divi_bits_256);
let mut a0rsh: __m256i = _mm256_sub_epi64(divi_bits_256, a0exp);
a0lsh = _mm256_srli_epi64(a0lsh, 52);
a0rsh = _mm256_srli_epi64(a0rsh, 52);
// compute the new mantissa
let mut a0pos: __m256i = _mm256_and_si256(_mm256_castpd_si256(a), mantissa_mask_256);
a0pos = _mm256_or_si256(a0pos, mantissa_msb_256);
a0lsh = _mm256_sllv_epi64(a0pos, a0lsh);
a0rsh = _mm256_srlv_epi64(a0pos, a0rsh);
let mut out: __m256i = _mm256_or_si256(a0lsh, a0rsh);
// negate if the sign was negative
out = _mm256_xor_si256(out, sign_mask);
out = _mm256_sub_epi64(out, sign_mask);
// stores
_mm256_storeu_si256(res_ptr_4xi64, out);
res_ptr_4xi64 = res_ptr_4xi64.add(1);
res_ptr_1xf64 = res_ptr_1xf64.add(4);
}
if !res.len().is_multiple_of(4) {
reim_to_znx_i64_inplace_ref(&mut res[span << 2..], divisor)
}
}
}
/// # Correctness
/// Only ensured for inputs with absolute value bounded by 2^50-1
/// # Safety
/// Caller must ensure the CPU supports AVX2 and FMA (e.g., via `is_x86_feature_detected!("avx2") && is_x86_feature_detected!("fma")`).
#[cfg(any(target_arch = "x86_64", target_arch = "x86"))]
#[target_feature(enable = "avx2,fma")]
#[allow(dead_code)]
pub fn reim_to_znx_i64_avx2_bnd50_fma(res: &mut [i64], divisor: f64, a: &[f64]) {
#[cfg(debug_assertions)]
{
assert_eq!(res.len(), a.len())
}
unsafe {
use std::arch::x86_64::{
__m256d, __m256i, _mm256_add_pd, _mm256_and_si256, _mm256_castpd_si256, _mm256_loadu_pd, _mm256_set1_epi64x,
_mm256_set1_pd, _mm256_storeu_si256, _mm256_sub_epi64,
};
let mantissa_mask: u64 = 0x000FFFFFFFFFFFFFu64;
let sub_cst: i64 = 1i64 << 51;
let add_cst: f64 = divisor * (3i64 << 51) as f64;
let sub_cst_4: __m256i = _mm256_set1_epi64x(sub_cst);
let add_cst_4: __m256d = _mm256_set1_pd(add_cst);
let mantissa_mask_4: __m256i = _mm256_set1_epi64x(mantissa_mask as i64);
let mut res_ptr: *mut __m256i = res.as_mut_ptr() as *mut __m256i;
let mut a_ptr = a.as_ptr();
let span: usize = res.len() >> 2;
for _ in 0..span {
// read the next value
let mut a: __m256d = _mm256_loadu_pd(a_ptr);
a = _mm256_add_pd(a, add_cst_4);
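// For a power-of-two `divisor`, adding divisor * 1.5 * 2^52 forces the result's
// exponent so that its low 52 mantissa bits hold round(a / divisor) + 2^51;
// the mask and subtraction below recover the rounded integer.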
let mut ai: __m256i = _mm256_castpd_si256(a);
ai = _mm256_and_si256(ai, mantissa_mask_4);
ai = _mm256_sub_epi64(ai, sub_cst_4);
// store the next value
_mm256_storeu_si256(res_ptr, ai);
res_ptr = res_ptr.add(1);
a_ptr = a_ptr.add(4);
}
if !res.len().is_multiple_of(4) {
use poulpy_hal::reference::fft64::reim::reim_to_znx_i64_ref;
reim_to_znx_i64_ref(&mut res[span << 2..], divisor, &a[span << 2..])
}
}
}

View File

@@ -0,0 +1,162 @@
# ----------------------------------------------------------------------
# This kernel is a direct port of the FFT16 routine from spqlios-arithmetic
# (https://github.com/tfhe/spqlios-arithmetic)
# ----------------------------------------------------------------------
#
.text
.globl fft16_avx2_fma_asm
.hidden fft16_avx2_fma_asm
.p2align 4, 0x90
.type fft16_avx2_fma_asm,@function
fft16_avx2_fma_asm:
.att_syntax prefix
# SysV args: %rdi = re*, %rsi = im*, %rdx = omg*
# stage 0: load inputs
vmovupd (%rdi),%ymm0 # ra0
vmovupd 0x20(%rdi),%ymm1 # ra4
vmovupd 0x40(%rdi),%ymm2 # ra8
vmovupd 0x60(%rdi),%ymm3 # ra12
vmovupd (%rsi),%ymm4 # ia0
vmovupd 0x20(%rsi),%ymm5 # ia4
vmovupd 0x40(%rsi),%ymm6 # ia8
vmovupd 0x60(%rsi),%ymm7 # ia12
# stage 1
vmovupd (%rdx),%xmm12
vinsertf128 $1, %xmm12, %ymm12, %ymm12 # omriri
vshufpd $15, %ymm12, %ymm12, %ymm13 # omai
vshufpd $0, %ymm12, %ymm12, %ymm12 # omar
vmulpd %ymm6,%ymm13,%ymm8
vmulpd %ymm7,%ymm13,%ymm9
vmulpd %ymm2,%ymm13,%ymm10
vmulpd %ymm3,%ymm13,%ymm11
vfmsub231pd %ymm2,%ymm12,%ymm8
vfmsub231pd %ymm3,%ymm12,%ymm9
vfmadd231pd %ymm6,%ymm12,%ymm10
vfmadd231pd %ymm7,%ymm12,%ymm11
vsubpd %ymm8,%ymm0,%ymm2
vsubpd %ymm9,%ymm1,%ymm3
vsubpd %ymm10,%ymm4,%ymm6
vsubpd %ymm11,%ymm5,%ymm7
vaddpd %ymm8,%ymm0,%ymm0
vaddpd %ymm9,%ymm1,%ymm1
vaddpd %ymm10,%ymm4,%ymm4
vaddpd %ymm11,%ymm5,%ymm5
# stage 2
vmovupd 16(%rdx),%xmm12
vinsertf128 $1, %xmm12, %ymm12, %ymm12 # omriri
vshufpd $15, %ymm12, %ymm12, %ymm13 # omai
vshufpd $0, %ymm12, %ymm12, %ymm12 # omar
vmulpd %ymm5,%ymm13,%ymm8
vmulpd %ymm7,%ymm12,%ymm9
vmulpd %ymm1,%ymm13,%ymm10
vmulpd %ymm3,%ymm12,%ymm11
vfmsub231pd %ymm1,%ymm12,%ymm8
vfmadd231pd %ymm3,%ymm13,%ymm9
vfmadd231pd %ymm5,%ymm12,%ymm10
vfmsub231pd %ymm7,%ymm13,%ymm11
vsubpd %ymm8,%ymm0,%ymm1
vaddpd %ymm9,%ymm2,%ymm3
vsubpd %ymm10,%ymm4,%ymm5
vaddpd %ymm11,%ymm6,%ymm7
vaddpd %ymm8,%ymm0,%ymm0
vsubpd %ymm9,%ymm2,%ymm2
vaddpd %ymm10,%ymm4,%ymm4
vsubpd %ymm11,%ymm6,%ymm6
# stage 3
vmovupd 0x20(%rdx),%ymm12
vshufpd $15, %ymm12, %ymm12, %ymm13 # omai
vshufpd $0, %ymm12, %ymm12, %ymm12 # omar
vperm2f128 $0x31,%ymm2,%ymm0,%ymm8
vperm2f128 $0x31,%ymm3,%ymm1,%ymm9
vperm2f128 $0x31,%ymm6,%ymm4,%ymm10
vperm2f128 $0x31,%ymm7,%ymm5,%ymm11
vperm2f128 $0x20,%ymm2,%ymm0,%ymm0
vperm2f128 $0x20,%ymm3,%ymm1,%ymm1
vperm2f128 $0x20,%ymm6,%ymm4,%ymm2
vperm2f128 $0x20,%ymm7,%ymm5,%ymm3
vmulpd %ymm10,%ymm13,%ymm4
vmulpd %ymm11,%ymm12,%ymm5
vmulpd %ymm8,%ymm13,%ymm6
vmulpd %ymm9,%ymm12,%ymm7
vfmsub231pd %ymm8,%ymm12,%ymm4
vfmadd231pd %ymm9,%ymm13,%ymm5
vfmadd231pd %ymm10,%ymm12,%ymm6
vfmsub231pd %ymm11,%ymm13,%ymm7
vsubpd %ymm4,%ymm0,%ymm8
vaddpd %ymm5,%ymm1,%ymm9
vsubpd %ymm6,%ymm2,%ymm10
vaddpd %ymm7,%ymm3,%ymm11
vaddpd %ymm4,%ymm0,%ymm0
vsubpd %ymm5,%ymm1,%ymm1
vaddpd %ymm6,%ymm2,%ymm2
vsubpd %ymm7,%ymm3,%ymm3
# stage 4
vmovupd 0x40(%rdx),%ymm12
vmovupd 0x60(%rdx),%ymm13
vunpckhpd %ymm1,%ymm0,%ymm4
vunpckhpd %ymm3,%ymm2,%ymm6
vunpckhpd %ymm9,%ymm8,%ymm5
vunpckhpd %ymm11,%ymm10,%ymm7
vunpcklpd %ymm1,%ymm0,%ymm0
vunpcklpd %ymm3,%ymm2,%ymm2
vunpcklpd %ymm9,%ymm8,%ymm1
vunpcklpd %ymm11,%ymm10,%ymm3
vmulpd %ymm6,%ymm13,%ymm8
vmulpd %ymm7,%ymm12,%ymm9
vmulpd %ymm4,%ymm13,%ymm10
vmulpd %ymm5,%ymm12,%ymm11
vfmsub231pd %ymm4,%ymm12,%ymm8
vfmadd231pd %ymm5,%ymm13,%ymm9
vfmadd231pd %ymm6,%ymm12,%ymm10
vfmsub231pd %ymm7,%ymm13,%ymm11
vsubpd %ymm8,%ymm0,%ymm4
vaddpd %ymm9,%ymm1,%ymm5
vsubpd %ymm10,%ymm2,%ymm6
vaddpd %ymm11,%ymm3,%ymm7
vaddpd %ymm8,%ymm0,%ymm0
vsubpd %ymm9,%ymm1,%ymm1
vaddpd %ymm10,%ymm2,%ymm2
vsubpd %ymm11,%ymm3,%ymm3
vunpckhpd %ymm7,%ymm3,%ymm11
vunpckhpd %ymm5,%ymm1,%ymm9
vunpcklpd %ymm7,%ymm3,%ymm10
vunpcklpd %ymm5,%ymm1,%ymm8
vunpckhpd %ymm6,%ymm2,%ymm3
vunpckhpd %ymm4,%ymm0,%ymm1
vunpcklpd %ymm6,%ymm2,%ymm2
vunpcklpd %ymm4,%ymm0,%ymm0
vperm2f128 $0x31,%ymm10,%ymm2,%ymm6
vperm2f128 $0x31,%ymm11,%ymm3,%ymm7
vperm2f128 $0x20,%ymm10,%ymm2,%ymm4
vperm2f128 $0x20,%ymm11,%ymm3,%ymm5
vperm2f128 $0x31,%ymm8,%ymm0,%ymm2
vperm2f128 $0x31,%ymm9,%ymm1,%ymm3
vperm2f128 $0x20,%ymm8,%ymm0,%ymm0
vperm2f128 $0x20,%ymm9,%ymm1,%ymm1
# stores
vmovupd %ymm0,(%rdi) # ra0
vmovupd %ymm1,0x20(%rdi) # ra4
vmovupd %ymm2,0x40(%rdi) # ra8
vmovupd %ymm3,0x60(%rdi) # ra12
vmovupd %ymm4,(%rsi) # ia0
vmovupd %ymm5,0x20(%rsi) # ia4
vmovupd %ymm6,0x40(%rsi) # ia8
vmovupd %ymm7,0x60(%rsi) # ia12
vzeroupper
ret
.size fft16_avx2_fma_asm, .-fft16_avx2_fma_asm
.section .note.GNU-stack,"",@progbits

View File

@@ -0,0 +1,278 @@
use std::arch::x86_64::{
__m128d, __m256d, _mm_load_pd, _mm256_add_pd, _mm256_fmadd_pd, _mm256_fmsub_pd, _mm256_loadu_pd, _mm256_mul_pd,
_mm256_permute2f128_pd, _mm256_set_m128d, _mm256_storeu_pd, _mm256_sub_pd, _mm256_unpackhi_pd, _mm256_unpacklo_pd,
};
use crate::cpu_fft64_avx::reim::{as_arr, as_arr_mut};
#[cfg(any(target_arch = "x86_64", target_arch = "x86"))]
#[target_feature(enable = "avx2,fma")]
pub(crate) fn fft_avx2_fma(m: usize, omg: &[f64], data: &mut [f64]) {
if m < 16 {
use poulpy_hal::reference::fft64::reim::fft_ref;
fft_ref(m, omg, data);
return;
}
assert!(data.len() == 2 * m);
let (re, im) = data.split_at_mut(m);
if m == 16 {
fft16_avx2_fma(
as_arr_mut::<16, f64>(re),
as_arr_mut::<16, f64>(im),
as_arr::<16, f64>(omg),
)
} else if m <= 2048 {
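// For m <= 2048 the re/im working set is at most 2 * 2048 * 8 B = 32 KiB, so a
// breadth-first schedule presumably stays cache-resident; larger sizes recurse
// first to halve the working set.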
fft_bfs_16_avx2_fma(m, re, im, omg, 0);
} else {
fft_rec_16_avx2_fma(m, re, im, omg, 0);
}
}
unsafe extern "sysv64" {
unsafe fn fft16_avx2_fma_asm(re: *mut f64, im: *mut f64, omg: *const f64);
}
#[cfg(any(target_arch = "x86_64", target_arch = "x86"))]
#[target_feature(enable = "avx2,fma")]
fn fft16_avx2_fma(re: &mut [f64; 16], im: &mut [f64; 16], omg: &[f64; 16]) {
unsafe {
fft16_avx2_fma_asm(re.as_mut_ptr(), im.as_mut_ptr(), omg.as_ptr());
}
}
#[cfg(any(target_arch = "x86_64", target_arch = "x86"))]
#[target_feature(enable = "avx2,fma")]
fn fft_rec_16_avx2_fma(m: usize, re: &mut [f64], im: &mut [f64], omg: &[f64], mut pos: usize) -> usize {
if m <= 2048 {
return fft_bfs_16_avx2_fma(m, re, im, omg, pos);
};
let h: usize = m >> 1;
twiddle_fft_avx2_fma(h, re, im, *as_arr::<2, f64>(&omg[pos..]));
pos += 2;
pos = fft_rec_16_avx2_fma(h, re, im, omg, pos);
pos = fft_rec_16_avx2_fma(h, &mut re[h..], &mut im[h..], omg, pos);
pos
}
#[cfg(any(target_arch = "x86_64", target_arch = "x86"))]
#[target_feature(enable = "avx2,fma")]
fn fft_bfs_16_avx2_fma(m: usize, re: &mut [f64], im: &mut [f64], omg: &[f64], mut pos: usize) -> usize {
let log_m: usize = (usize::BITS - (m - 1).leading_zeros()) as usize;
let mut mm: usize = m;
if !log_m.is_multiple_of(2) {
let h: usize = mm >> 1;
twiddle_fft_avx2_fma(h, re, im, *as_arr::<2, f64>(&omg[pos..]));
pos += 2;
mm = h
}
while mm > 16 {
let h: usize = mm >> 2;
for off in (0..m).step_by(mm) {
bitwiddle_fft_avx2_fma(
h,
&mut re[off..],
&mut im[off..],
as_arr::<4, f64>(&omg[pos..]),
);
pos += 4;
}
mm = h
}
for off in (0..m).step_by(16) {
fft16_avx2_fma(
as_arr_mut::<16, f64>(&mut re[off..]),
as_arr_mut::<16, f64>(&mut im[off..]),
as_arr::<16, f64>(&omg[pos..]),
);
pos += 16;
}
pos
}
#[cfg(any(target_arch = "x86_64", target_arch = "x86"))]
#[target_feature(enable = "avx2,fma")]
fn twiddle_fft_avx2_fma(h: usize, re: &mut [f64], im: &mut [f64], omg: [f64; 2]) {
unsafe {
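// Twiddle over pairs (x[j], x[j + h]) with x = re + i*im: the FMAs below form
// t = om * x[j + h], then x[j + h] <- x[j] - t and x[j] <- x[j] + t.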
let omx: __m128d = _mm_load_pd(omg.as_ptr());
let omra: __m256d = _mm256_set_m128d(omx, omx);
let omi: __m256d = _mm256_unpackhi_pd(omra, omra);
let omr: __m256d = _mm256_unpacklo_pd(omra, omra);
let mut r0: *mut f64 = re.as_mut_ptr();
let mut r1: *mut f64 = re.as_mut_ptr().add(h);
let mut i0: *mut f64 = im.as_mut_ptr();
let mut i1: *mut f64 = im.as_mut_ptr().add(h);
for _ in (0..h).step_by(4) {
let mut ur0: __m256d = _mm256_loadu_pd(r0);
let mut ur1: __m256d = _mm256_loadu_pd(r1);
let mut ui0: __m256d = _mm256_loadu_pd(i0);
let mut ui1: __m256d = _mm256_loadu_pd(i1);
let mut tra: __m256d = _mm256_mul_pd(omi, ui1);
let mut tia: __m256d = _mm256_mul_pd(omi, ur1);
tra = _mm256_fmsub_pd(omr, ur1, tra);
tia = _mm256_fmadd_pd(omr, ui1, tia);
ur1 = _mm256_sub_pd(ur0, tra);
ui1 = _mm256_sub_pd(ui0, tia);
ur0 = _mm256_add_pd(ur0, tra);
ui0 = _mm256_add_pd(ui0, tia);
_mm256_storeu_pd(r0, ur0);
_mm256_storeu_pd(r1, ur1);
_mm256_storeu_pd(i0, ui0);
_mm256_storeu_pd(i1, ui1);
r0 = r0.add(4);
r1 = r1.add(4);
i0 = i0.add(4);
i1 = i1.add(4);
}
}
}
#[cfg(any(target_arch = "x86_64", target_arch = "x86"))]
#[target_feature(enable = "avx2,fma")]
fn bitwiddle_fft_avx2_fma(h: usize, re: &mut [f64], im: &mut [f64], omg: &[f64; 4]) {
unsafe {
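// Radix-4 step over quarters (x0, x1, x2, x3) at stride h: a first pass of
// butterflies twiddles (x2, x3) by oma against (x0, x1); a second pass then
// twiddles x1 by omb into x0 and x3 by a 90-degree-rotated omb into x2.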
let mut r0: *mut f64 = re.as_mut_ptr();
let mut r1: *mut f64 = re.as_mut_ptr().add(h);
let mut r2: *mut f64 = re.as_mut_ptr().add(2 * h);
let mut r3: *mut f64 = re.as_mut_ptr().add(3 * h);
let mut i0: *mut f64 = im.as_mut_ptr();
let mut i1: *mut f64 = im.as_mut_ptr().add(h);
let mut i2: *mut f64 = im.as_mut_ptr().add(2 * h);
let mut i3: *mut f64 = im.as_mut_ptr().add(3 * h);
let om0: __m256d = _mm256_loadu_pd(omg.as_ptr());
let omb: __m256d = _mm256_permute2f128_pd(om0, om0, 0x11);
let oma: __m256d = _mm256_permute2f128_pd(om0, om0, 0x00);
let omai: __m256d = _mm256_unpackhi_pd(oma, oma);
let omar: __m256d = _mm256_unpacklo_pd(oma, oma);
let ombi: __m256d = _mm256_unpackhi_pd(omb, omb);
let ombr: __m256d = _mm256_unpacklo_pd(omb, omb);
for _ in (0..h).step_by(4) {
let mut ur0: __m256d = _mm256_loadu_pd(r0);
let mut ur1: __m256d = _mm256_loadu_pd(r1);
let mut ur2: __m256d = _mm256_loadu_pd(r2);
let mut ur3: __m256d = _mm256_loadu_pd(r3);
let mut ui0: __m256d = _mm256_loadu_pd(i0);
let mut ui1: __m256d = _mm256_loadu_pd(i1);
let mut ui2: __m256d = _mm256_loadu_pd(i2);
let mut ui3: __m256d = _mm256_loadu_pd(i3);
let mut tra: __m256d = _mm256_mul_pd(omai, ui2);
let mut trb: __m256d = _mm256_mul_pd(omai, ui3);
let mut tia: __m256d = _mm256_mul_pd(omai, ur2);
let mut tib: __m256d = _mm256_mul_pd(omai, ur3);
tra = _mm256_fmsub_pd(omar, ur2, tra);
trb = _mm256_fmsub_pd(omar, ur3, trb);
tia = _mm256_fmadd_pd(omar, ui2, tia);
tib = _mm256_fmadd_pd(omar, ui3, tib);
ur2 = _mm256_sub_pd(ur0, tra);
ur3 = _mm256_sub_pd(ur1, trb);
ui2 = _mm256_sub_pd(ui0, tia);
ui3 = _mm256_sub_pd(ui1, tib);
ur0 = _mm256_add_pd(ur0, tra);
ur1 = _mm256_add_pd(ur1, trb);
ui0 = _mm256_add_pd(ui0, tia);
ui1 = _mm256_add_pd(ui1, tib);
tra = _mm256_mul_pd(ombi, ui1);
trb = _mm256_mul_pd(ombr, ui3);
tia = _mm256_mul_pd(ombi, ur1);
tib = _mm256_mul_pd(ombr, ur3);
tra = _mm256_fmsub_pd(ombr, ur1, tra);
trb = _mm256_fmadd_pd(ombi, ur3, trb);
tia = _mm256_fmadd_pd(ombr, ui1, tia);
tib = _mm256_fmsub_pd(ombi, ui3, tib);
ur1 = _mm256_sub_pd(ur0, tra);
ur3 = _mm256_add_pd(ur2, trb);
ui1 = _mm256_sub_pd(ui0, tia);
ui3 = _mm256_add_pd(ui2, tib);
ur0 = _mm256_add_pd(ur0, tra);
ur2 = _mm256_sub_pd(ur2, trb);
ui0 = _mm256_add_pd(ui0, tia);
ui2 = _mm256_sub_pd(ui2, tib);
_mm256_storeu_pd(r0, ur0);
_mm256_storeu_pd(r1, ur1);
_mm256_storeu_pd(r2, ur2);
_mm256_storeu_pd(r3, ur3);
_mm256_storeu_pd(i0, ui0);
_mm256_storeu_pd(i1, ui1);
_mm256_storeu_pd(i2, ui2);
_mm256_storeu_pd(i3, ui3);
r0 = r0.add(4);
r1 = r1.add(4);
r2 = r2.add(4);
r3 = r3.add(4);
i0 = i0.add(4);
i1 = i1.add(4);
i2 = i2.add(4);
i3 = i3.add(4);
}
}
}
#[test]
fn test_fft_avx2_fma() {
use super::*;
#[cfg(any(target_arch = "x86_64", target_arch = "x86"))]
#[target_feature(enable = "avx2,fma")]
fn internal(log_m: usize) {
use poulpy_hal::reference::fft64::reim::ReimFFTRef;
let m = 1 << log_m;
let table: ReimFFTTable<f64> = ReimFFTTable::<f64>::new(m);
let mut values_0: Vec<f64> = vec![0f64; m << 1];
let scale: f64 = 1.0f64 / m as f64;
values_0
.iter_mut()
.enumerate()
.for_each(|(i, x)| *x = (i + 1) as f64 * scale);
let mut values_1: Vec<f64> = values_0.clone();
ReimFFTAvx::reim_dft_execute(&table, &mut values_0);
ReimFFTRef::reim_dft_execute(&table, &mut values_1);
let max_diff: f64 = 1.0 / ((1u64 << (53 - log_m - 1)) as f64);
for i in 0..m * 2 {
let diff: f64 = (values_0[i] - values_1[i]).abs();
assert!(
diff <= max_diff,
"{} -> {}-{} = {}",
i,
values_0[i],
values_1[i],
diff
)
}
}
if std::is_x86_feature_detected!("avx2") {
for log_m in 0..16 {
unsafe { internal(log_m) }
}
} else {
eprintln!("skipping: CPU lacks avx2");
}
}

View File

@@ -0,0 +1,350 @@
/// # Safety
/// Caller must ensure the CPU supports AVX2 and FMA (e.g., via `is_x86_feature_detected!("avx2") && is_x86_feature_detected!("fma")`).
#[cfg(any(target_arch = "x86_64", target_arch = "x86"))]
#[target_feature(enable = "avx2,fma")]
pub fn reim_add_avx2_fma(res: &mut [f64], a: &[f64], b: &[f64]) {
#[cfg(debug_assertions)]
{
assert_eq!(a.len(), res.len());
assert_eq!(b.len(), res.len());
}
use std::arch::x86_64::{__m256d, _mm256_add_pd, _mm256_loadu_pd, _mm256_storeu_pd};
let span: usize = res.len() >> 2;
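// Only the first 4 * (len / 4) lanes are processed; any tail is left untouched,
// so these kernels presumably rely on power-of-two lengths >= 4.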
unsafe {
let mut rr: *mut f64 = res.as_mut_ptr();
let mut aa: *const f64 = a.as_ptr();
let mut bb: *const f64 = b.as_ptr();
for _ in 0..span {
let a_256: __m256d = _mm256_loadu_pd(aa);
let b_256: __m256d = _mm256_loadu_pd(bb);
_mm256_storeu_pd(rr, _mm256_add_pd(a_256, b_256));
rr = rr.add(4);
aa = aa.add(4);
bb = bb.add(4);
}
}
}
/// # Safety
/// Caller must ensure the CPU supports AVX2 and FMA (e.g., via `is_x86_feature_detected!("avx2") && is_x86_feature_detected!("fma")`).
#[cfg(any(target_arch = "x86_64", target_arch = "x86"))]
#[target_feature(enable = "avx2,fma")]
pub fn reim_add_inplace_avx2_fma(res: &mut [f64], a: &[f64]) {
#[cfg(debug_assertions)]
{
assert_eq!(a.len(), res.len());
}
use std::arch::x86_64::{__m256d, _mm256_add_pd, _mm256_loadu_pd, _mm256_storeu_pd};
let span: usize = res.len() >> 2;
unsafe {
let mut rr: *mut f64 = res.as_mut_ptr();
let mut aa: *const f64 = a.as_ptr();
for _ in 0..span {
let a_256: __m256d = _mm256_loadu_pd(aa);
let r_256: __m256d = _mm256_loadu_pd(rr);
_mm256_storeu_pd(rr, _mm256_add_pd(r_256, a_256));
rr = rr.add(4);
aa = aa.add(4);
}
}
}
/// # Safety
/// Caller must ensure the CPU supports AVX2 and FMA (e.g., via `is_x86_feature_detected!("avx2") && is_x86_feature_detected!("fma")`).
#[cfg(any(target_arch = "x86_64", target_arch = "x86"))]
#[target_feature(enable = "avx2,fma")]
pub fn reim_sub_avx2_fma(res: &mut [f64], a: &[f64], b: &[f64]) {
#[cfg(debug_assertions)]
{
assert_eq!(a.len(), res.len());
assert_eq!(b.len(), res.len());
}
use std::arch::x86_64::{__m256d, _mm256_loadu_pd, _mm256_storeu_pd, _mm256_sub_pd};
let span: usize = res.len() >> 2;
unsafe {
let mut rr: *mut f64 = res.as_mut_ptr();
let mut aa: *const f64 = a.as_ptr();
let mut bb: *const f64 = b.as_ptr();
for _ in 0..span {
let a_256: __m256d = _mm256_loadu_pd(aa);
let b_256: __m256d = _mm256_loadu_pd(bb);
_mm256_storeu_pd(rr, _mm256_sub_pd(a_256, b_256));
rr = rr.add(4);
aa = aa.add(4);
bb = bb.add(4);
}
}
}
/// # Safety
/// Caller must ensure the CPU supports AVX2 and FMA (e.g., via `is_x86_feature_detected!("avx2") && is_x86_feature_detected!("fma")`).
#[cfg(any(target_arch = "x86_64", target_arch = "x86"))]
#[target_feature(enable = "avx2,fma")]
pub fn reim_sub_ab_inplace_avx2_fma(res: &mut [f64], a: &[f64]) {
#[cfg(debug_assertions)]
{
assert_eq!(a.len(), res.len());
}
use std::arch::x86_64::{__m256d, _mm256_loadu_pd, _mm256_storeu_pd, _mm256_sub_pd};
let span: usize = res.len() >> 2;
unsafe {
let mut rr: *mut f64 = res.as_mut_ptr();
let mut aa: *const f64 = a.as_ptr();
for _ in 0..span {
let a_256: __m256d = _mm256_loadu_pd(aa);
let r_256: __m256d = _mm256_loadu_pd(rr);
_mm256_storeu_pd(rr, _mm256_sub_pd(r_256, a_256));
rr = rr.add(4);
aa = aa.add(4);
}
}
}
/// # Safety
/// Caller must ensure the CPU supports AVX2 and FMA (e.g., via `is_x86_feature_detected!("avx2") && is_x86_feature_detected!("fma")`).
#[cfg(any(target_arch = "x86_64", target_arch = "x86"))]
#[target_feature(enable = "avx2,fma")]
pub fn reim_sub_ba_inplace_avx2_fma(res: &mut [f64], a: &[f64]) {
#[cfg(debug_assertions)]
{
assert_eq!(a.len(), res.len());
}
use std::arch::x86_64::{__m256d, _mm256_loadu_pd, _mm256_storeu_pd, _mm256_sub_pd};
let span: usize = res.len() >> 2;
unsafe {
let mut rr: *mut f64 = res.as_mut_ptr();
let mut aa: *const f64 = a.as_ptr();
for _ in 0..span {
let a_256: __m256d = _mm256_loadu_pd(aa);
let r_256: __m256d = _mm256_loadu_pd(rr);
_mm256_storeu_pd(rr, _mm256_sub_pd(a_256, r_256));
rr = rr.add(4);
aa = aa.add(4);
}
}
}
/// # Safety
/// Caller must ensure the CPU supports AVX2 and FMA (e.g., via `is_x86_feature_detected!("avx2") && is_x86_feature_detected!("fma")`).
#[cfg(any(target_arch = "x86_64", target_arch = "x86"))]
#[target_feature(enable = "avx2,fma")]
pub fn reim_negate_avx2_fma(res: &mut [f64], a: &[f64]) {
#[cfg(debug_assertions)]
{
assert_eq!(a.len(), res.len());
}
use std::arch::x86_64::{__m256d, _mm256_loadu_pd, _mm256_storeu_pd, _mm256_xor_pd};
let span: usize = res.len() >> 2;
unsafe {
use std::arch::x86_64::_mm256_set1_pd;
let mut rr: *mut f64 = res.as_mut_ptr();
let mut aa: *const f64 = a.as_ptr();
let neg0: __m256d = _mm256_set1_pd(-0.0);
for _ in 0..span {
let a_256: __m256d = _mm256_loadu_pd(aa);
_mm256_storeu_pd(rr, _mm256_xor_pd(a_256, neg0));
rr = rr.add(4);
aa = aa.add(4);
}
}
}
/// # Safety
/// Caller must ensure the CPU supports AVX2 and FMA (e.g., via `is_x86_feature_detected!("avx2") && is_x86_feature_detected!("fma")`).
#[cfg(any(target_arch = "x86_64", target_arch = "x86"))]
#[target_feature(enable = "avx2,fma")]
pub fn reim_negate_inplace_avx2_fma(res: &mut [f64]) {
use std::arch::x86_64::{__m256d, _mm256_loadu_pd, _mm256_storeu_pd, _mm256_xor_pd};
let span: usize = res.len() >> 2;
unsafe {
use std::arch::x86_64::_mm256_set1_pd;
let mut rr: *mut f64 = res.as_mut_ptr();
let neg0: __m256d = _mm256_set1_pd(-0.0);
for _ in 0..span {
let r_256: __m256d = _mm256_loadu_pd(rr);
_mm256_storeu_pd(rr, _mm256_xor_pd(r_256, neg0));
rr = rr.add(4);
}
}
}
/// # Safety
/// Caller must ensure the CPU supports AVX2 and FMA (e.g., via `is_x86_feature_detected!("avx2") && is_x86_feature_detected!("fma")`).
#[cfg(any(target_arch = "x86_64", target_arch = "x86"))]
#[target_feature(enable = "avx2,fma")]
pub fn reim_addmul_avx2_fma(res: &mut [f64], a: &[f64], b: &[f64]) {
#[cfg(debug_assertions)]
{
assert_eq!(a.len(), res.len());
assert_eq!(b.len(), res.len());
}
let m: usize = res.len() >> 1;
let (rr, ri) = res.split_at_mut(m);
let (ar, ai) = a.split_at(m);
let (br, bi) = b.split_at(m);
unsafe {
let mut rr_ptr: *mut f64 = rr.as_mut_ptr();
let mut ri_ptr: *mut f64 = ri.as_mut_ptr();
let mut ar_ptr: *const f64 = ar.as_ptr();
let mut ai_ptr: *const f64 = ai.as_ptr();
let mut br_ptr: *const f64 = br.as_ptr();
let mut bi_ptr: *const f64 = bi.as_ptr();
use std::arch::x86_64::{__m256d, _mm256_fmadd_pd, _mm256_fmsub_pd, _mm256_loadu_pd, _mm256_storeu_pd};
for _ in 0..(m >> 2) {
let mut rr: __m256d = _mm256_loadu_pd(rr_ptr);
let mut ri: __m256d = _mm256_loadu_pd(ri_ptr);
let ar: __m256d = _mm256_loadu_pd(ar_ptr);
let ai: __m256d = _mm256_loadu_pd(ai_ptr);
let br: __m256d = _mm256_loadu_pd(br_ptr);
let bi: __m256d = _mm256_loadu_pd(bi_ptr);
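// res += a * b (complex): the chained FMAs compute
// rr + ar*br - ai*bi as ar*br - (ai*bi - rr), and ri + ar*bi + ai*br directly.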
rr = _mm256_fmsub_pd(ai, bi, rr);
rr = _mm256_fmsub_pd(ar, br, rr);
ri = _mm256_fmadd_pd(ar, bi, ri);
ri = _mm256_fmadd_pd(ai, br, ri);
_mm256_storeu_pd(rr_ptr, rr);
_mm256_storeu_pd(ri_ptr, ri);
rr_ptr = rr_ptr.add(4);
ri_ptr = ri_ptr.add(4);
ar_ptr = ar_ptr.add(4);
ai_ptr = ai_ptr.add(4);
br_ptr = br_ptr.add(4);
bi_ptr = bi_ptr.add(4);
}
}
}
/// # Safety
/// Caller must ensure the CPU supports AVX2 and FMA (e.g., via `is_x86_feature_detected!("avx2") && is_x86_feature_detected!("fma")`).
#[cfg(any(target_arch = "x86_64", target_arch = "x86"))]
#[target_feature(enable = "avx2,fma")]
pub fn reim_mul_avx2_fma(res: &mut [f64], a: &[f64], b: &[f64]) {
#[cfg(debug_assertions)]
{
assert_eq!(a.len(), res.len());
assert_eq!(b.len(), res.len());
}
let m: usize = res.len() >> 1;
let (rr, ri) = res.split_at_mut(m);
let (ar, ai) = a.split_at(m);
let (br, bi) = b.split_at(m);
unsafe {
let mut rr_ptr: *mut f64 = rr.as_mut_ptr();
let mut ri_ptr: *mut f64 = ri.as_mut_ptr();
let mut ar_ptr: *const f64 = ar.as_ptr();
let mut ai_ptr: *const f64 = ai.as_ptr();
let mut br_ptr: *const f64 = br.as_ptr();
let mut bi_ptr: *const f64 = bi.as_ptr();
use std::arch::x86_64::{__m256d, _mm256_fmadd_pd, _mm256_fmsub_pd, _mm256_loadu_pd, _mm256_mul_pd, _mm256_storeu_pd};
for _ in 0..(m >> 2) {
let ar: __m256d = _mm256_loadu_pd(ar_ptr);
let ai: __m256d = _mm256_loadu_pd(ai_ptr);
let br: __m256d = _mm256_loadu_pd(br_ptr);
let bi: __m256d = _mm256_loadu_pd(bi_ptr);
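// (rr, ri) = (ar*br - ai*bi, ai*br + ar*bi), i.e. the complex product a * b.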
let t1: __m256d = _mm256_mul_pd(ai, bi);
let t2: __m256d = _mm256_mul_pd(ar, bi);
let rr: __m256d = _mm256_fmsub_pd(ar, br, t1);
let ri: __m256d = _mm256_fmadd_pd(ai, br, t2);
_mm256_storeu_pd(rr_ptr, rr);
_mm256_storeu_pd(ri_ptr, ri);
rr_ptr = rr_ptr.add(4);
ri_ptr = ri_ptr.add(4);
ar_ptr = ar_ptr.add(4);
ai_ptr = ai_ptr.add(4);
br_ptr = br_ptr.add(4);
bi_ptr = bi_ptr.add(4);
}
}
}
/// # Safety
/// Caller must ensure the CPU supports AVX2 and FMA (e.g., via `is_x86_feature_detected!("avx2") && is_x86_feature_detected!("fma")`).
#[cfg(any(target_arch = "x86_64", target_arch = "x86"))]
#[target_feature(enable = "avx2,fma")]
pub fn reim_mul_inplace_avx2_fma(res: &mut [f64], a: &[f64]) {
#[cfg(debug_assertions)]
{
assert_eq!(a.len(), res.len());
}
let m: usize = res.len() >> 1;
let (rr, ri) = res.split_at_mut(m);
let (ar, ai) = a.split_at(m);
unsafe {
let mut rr_ptr: *mut f64 = rr.as_mut_ptr();
let mut ri_ptr: *mut f64 = ri.as_mut_ptr();
let mut ar_ptr: *const f64 = ar.as_ptr();
let mut ai_ptr: *const f64 = ai.as_ptr();
use std::arch::x86_64::{__m256d, _mm256_fmadd_pd, _mm256_fmsub_pd, _mm256_loadu_pd, _mm256_mul_pd, _mm256_storeu_pd};
for _ in 0..(m >> 2) {
let ar: __m256d = _mm256_loadu_pd(ar_ptr);
let ai: __m256d = _mm256_loadu_pd(ai_ptr);
let br: __m256d = _mm256_loadu_pd(rr_ptr);
let bi: __m256d = _mm256_loadu_pd(ri_ptr);
let t1: __m256d = _mm256_mul_pd(ai, bi);
let t2: __m256d = _mm256_mul_pd(ar, bi);
let rr = _mm256_fmsub_pd(ar, br, t1);
let ri = _mm256_fmadd_pd(ai, br, t2);
_mm256_storeu_pd(rr_ptr, rr);
_mm256_storeu_pd(ri_ptr, ri);
rr_ptr = rr_ptr.add(4);
ri_ptr = ri_ptr.add(4);
ar_ptr = ar_ptr.add(4);
ai_ptr = ai_ptr.add(4);
}
}
}

View File

@@ -0,0 +1,181 @@
# ----------------------------------------------------------------------
# This kernel is a direct port of the IFFT16 routine from spqlios-arithmetic
# (https://github.com/tfhe/spqlios-arithmetic)
# ----------------------------------------------------------------------
#
.text
.globl ifft16_avx2_fma_asm
.hidden ifft16_avx2_fma_asm
.p2align 4, 0x90
.type ifft16_avx2_fma_asm,@function
ifft16_avx2_fma_asm:
.att_syntax prefix
vmovupd (%rdi),%ymm0 # ra0
vmovupd 0x20(%rdi),%ymm1 # ra4
vmovupd 0x40(%rdi),%ymm2 # ra8
vmovupd 0x60(%rdi),%ymm3 # ra12
vmovupd (%rsi),%ymm4 # ia0
vmovupd 0x20(%rsi),%ymm5 # ia4
vmovupd 0x40(%rsi),%ymm6 # ia8
vmovupd 0x60(%rsi),%ymm7 # ia12
1:
vmovupd 0x00(%rdx),%ymm12
vmovupd 0x20(%rdx),%ymm13
vperm2f128 $0x31,%ymm2,%ymm0,%ymm8 # ymm8 contains re to mul (tw)
vperm2f128 $0x31,%ymm3,%ymm1,%ymm9 # ymm9 contains re to mul (itw)
vperm2f128 $0x31,%ymm6,%ymm4,%ymm10 # ymm10 contains im to mul (tw)
vperm2f128 $0x31,%ymm7,%ymm5,%ymm11 # ymm11 contains im to mul (itw)
vperm2f128 $0x20,%ymm2,%ymm0,%ymm0 # ymm0 contains re to add (tw)
vperm2f128 $0x20,%ymm3,%ymm1,%ymm1 # ymm1 contains re to add (itw)
vperm2f128 $0x20,%ymm6,%ymm4,%ymm2 # ymm2 contains im to add (tw)
vperm2f128 $0x20,%ymm7,%ymm5,%ymm3 # ymm3 contains im to add (itw)
vunpckhpd %ymm1,%ymm0,%ymm4 # (0,1) -> (0,4)
vunpckhpd %ymm3,%ymm2,%ymm6 # (2,3) -> (2,6)
vunpckhpd %ymm9,%ymm8,%ymm5 # (8,9) -> (1,5)
vunpckhpd %ymm11,%ymm10,%ymm7 # (10,11) -> (3,7)
vunpcklpd %ymm1,%ymm0,%ymm0
vunpcklpd %ymm3,%ymm2,%ymm2
vunpcklpd %ymm9,%ymm8,%ymm1
vunpcklpd %ymm11,%ymm10,%ymm3
# invctwiddle Re:(ymm0,ymm4) and Im:(ymm2,ymm6) with omega=(ymm12,ymm13)
# invcitwiddle Re:(ymm1,ymm5) and Im:(ymm3,ymm7) with omega=(ymm12,ymm13)
vsubpd %ymm4,%ymm0,%ymm8 # retw
vsubpd %ymm5,%ymm1,%ymm9 # reitw
vsubpd %ymm6,%ymm2,%ymm10 # imtw
vsubpd %ymm7,%ymm3,%ymm11 # imitw
vaddpd %ymm4,%ymm0,%ymm0
vaddpd %ymm5,%ymm1,%ymm1
vaddpd %ymm6,%ymm2,%ymm2
vaddpd %ymm7,%ymm3,%ymm3
# multiply 8,9,10,11 by 12,13, result to: 4,5,6,7
# twiddles use reom=ymm12, imom=ymm13
# invtwiddles use reom=ymm13, imom=-ymm12
vmulpd %ymm10,%ymm13,%ymm4 # imtw.omai (tw)
vmulpd %ymm11,%ymm12,%ymm5 # imitw.omar (itw)
vmulpd %ymm8,%ymm13,%ymm6 # retw.omai (tw)
vmulpd %ymm9,%ymm12,%ymm7 # reitw.omar (itw)
vfmsub231pd %ymm8,%ymm12,%ymm4 # rprod0 (tw)
vfmadd231pd %ymm9,%ymm13,%ymm5 # rprod4 (itw)
vfmadd231pd %ymm10,%ymm12,%ymm6 # iprod0 (tw)
vfmsub231pd %ymm11,%ymm13,%ymm7 # iprod4 (itw)
vunpckhpd %ymm7,%ymm3,%ymm11 # (0,4) -> (0,1)
vunpckhpd %ymm5,%ymm1,%ymm9 # (2,6) -> (2,3)
vunpcklpd %ymm7,%ymm3,%ymm10
vunpcklpd %ymm5,%ymm1,%ymm8
vunpckhpd %ymm6,%ymm2,%ymm3 # (1,5) -> (8,9)
vunpckhpd %ymm4,%ymm0,%ymm1 # (3,7) -> (10,11)
vunpcklpd %ymm6,%ymm2,%ymm2
vunpcklpd %ymm4,%ymm0,%ymm0
2:
vmovupd 0x40(%rdx),%ymm12
vshufpd $15, %ymm12, %ymm12, %ymm13 # ymm13: omaiii'i'
vshufpd $0, %ymm12, %ymm12, %ymm12 # ymm12: omarrr'r'
# invctwiddle Re:(ymm0,ymm8) and Im:(ymm2,ymm10) with omega=(ymm12,ymm13)
# invcitwiddle Re:(ymm1,ymm9) and Im:(ymm3,ymm11) with omega=(ymm12,ymm13)
vsubpd %ymm8,%ymm0,%ymm4 # retw
vsubpd %ymm9,%ymm1,%ymm5 # reitw
vsubpd %ymm10,%ymm2,%ymm6 # imtw
vsubpd %ymm11,%ymm3,%ymm7 # imitw
vaddpd %ymm8,%ymm0,%ymm0
vaddpd %ymm9,%ymm1,%ymm1
vaddpd %ymm10,%ymm2,%ymm2
vaddpd %ymm11,%ymm3,%ymm3
# multiply 4,5,6,7 by 12,13, result to 8,9,10,11
# twiddles use reom=ymm12, imom=ymm13
# invtwiddles use reom=ymm13, imom=-ymm12
vmulpd %ymm6,%ymm13,%ymm8 # imtw.omai (tw)
vmulpd %ymm7,%ymm12,%ymm9 # imitw.omar (itw)
vmulpd %ymm4,%ymm13,%ymm10 # retw.omai (tw)
vmulpd %ymm5,%ymm12,%ymm11 # reitw.omar (itw)
vfmsub231pd %ymm4,%ymm12,%ymm8 # rprod0 (tw)
vfmadd231pd %ymm5,%ymm13,%ymm9 # rprod4 (itw)
vfmadd231pd %ymm6,%ymm12,%ymm10 # iprod0 (tw)
vfmsub231pd %ymm7,%ymm13,%ymm11 # iprod4 (itw)
vperm2f128 $0x31,%ymm10,%ymm2,%ymm6
vperm2f128 $0x31,%ymm11,%ymm3,%ymm7
vperm2f128 $0x20,%ymm10,%ymm2,%ymm4
vperm2f128 $0x20,%ymm11,%ymm3,%ymm5
vperm2f128 $0x31,%ymm8,%ymm0,%ymm2
vperm2f128 $0x31,%ymm9,%ymm1,%ymm3
vperm2f128 $0x20,%ymm8,%ymm0,%ymm0
vperm2f128 $0x20,%ymm9,%ymm1,%ymm1
3:
vmovupd 0x60(%rdx),%xmm12
vinsertf128 $1, %xmm12, %ymm12, %ymm12 # omriri
vshufpd $15, %ymm12, %ymm12, %ymm13 # ymm13: omai
vshufpd $0, %ymm12, %ymm12, %ymm12 # ymm12: omar
# invctwiddle Re:(ymm0,ymm1) and Im:(ymm4,ymm5) with omega=(ymm12,ymm13)
# invcitwiddle Re:(ymm2,ymm3) and Im:(ymm6,ymm7) with omega=(ymm12,ymm13)
vsubpd %ymm1,%ymm0,%ymm8 # retw
vsubpd %ymm3,%ymm2,%ymm9 # reitw
vsubpd %ymm5,%ymm4,%ymm10 # imtw
vsubpd %ymm7,%ymm6,%ymm11 # imitw
vaddpd %ymm1,%ymm0,%ymm0
vaddpd %ymm3,%ymm2,%ymm2
vaddpd %ymm5,%ymm4,%ymm4
vaddpd %ymm7,%ymm6,%ymm6
# multiply 8,9,10,11 by 12,13, result to 1,3,5,7
# twiddles use reom=ymm12, imom=ymm13
# invtwiddles use reom=ymm13, imom=-ymm12
vmulpd %ymm10,%ymm13,%ymm1 # imtw.omai (tw)
vmulpd %ymm11,%ymm12,%ymm3 # imitw.omar (itw)
vmulpd %ymm8,%ymm13,%ymm5 # retw.omai (tw)
vmulpd %ymm9,%ymm12,%ymm7 # reitw.omar (itw)
vfmsub231pd %ymm8,%ymm12,%ymm1 # rprod0 (tw)
vfmadd231pd %ymm9,%ymm13,%ymm3 # rprod4 (itw)
vfmadd231pd %ymm10,%ymm12,%ymm5 # iprod0 (tw)
vfmsub231pd %ymm11,%ymm13,%ymm7 # iprod4 (itw)
4:
vmovupd 0x70(%rdx),%xmm12
vinsertf128 $1, %xmm12, %ymm12, %ymm12 # omriri
vshufpd $15, %ymm12, %ymm12, %ymm13 # ymm13: omai
vshufpd $0, %ymm12, %ymm12, %ymm12 # ymm12: omar
# invctwiddle Re:(ymm0,ymm2) and Im:(ymm4,ymm6) with omega=(ymm12,ymm13)
# invctwiddle Re:(ymm1,ymm3) and Im:(ymm5,ymm7) with omega=(ymm12,ymm13)
vsubpd %ymm2,%ymm0,%ymm8 # retw1
vsubpd %ymm3,%ymm1,%ymm9 # retw2
vsubpd %ymm6,%ymm4,%ymm10 # imtw1
vsubpd %ymm7,%ymm5,%ymm11 # imtw2
vaddpd %ymm2,%ymm0,%ymm0
vaddpd %ymm3,%ymm1,%ymm1
vaddpd %ymm6,%ymm4,%ymm4
vaddpd %ymm7,%ymm5,%ymm5
# multiply 8,9,10,11 by 12,13, result to 2,3,6,7
# twiddles use reom=ymm12, imom=ymm13
vmulpd %ymm10,%ymm13,%ymm2 # imtw1.omai
vmulpd %ymm11,%ymm13,%ymm3 # imtw2.omai
vmulpd %ymm8,%ymm13,%ymm6 # retw1.omai
vmulpd %ymm9,%ymm13,%ymm7 # retw2.omai
vfmsub231pd %ymm8,%ymm12,%ymm2 # rprod0
vfmsub231pd %ymm9,%ymm12,%ymm3 # rprod4
vfmadd231pd %ymm10,%ymm12,%ymm6 # iprod0
vfmadd231pd %ymm11,%ymm12,%ymm7 # iprod4
5:
vmovupd %ymm0,(%rdi) # ra0
vmovupd %ymm1,0x20(%rdi) # ra4
vmovupd %ymm2,0x40(%rdi) # ra8
vmovupd %ymm3,0x60(%rdi) # ra12
vmovupd %ymm4,(%rsi) # ia0
vmovupd %ymm5,0x20(%rsi) # ia4
vmovupd %ymm6,0x40(%rsi) # ia8
vmovupd %ymm7,0x60(%rsi) # ia12
vzeroupper
ret
.size ifft16_avx2_fma_asm, .-ifft16_avx2_fma_asm
.section .note.GNU-stack,"",@progbits

View File

@@ -0,0 +1,271 @@
use std::arch::x86_64::{
__m128d, __m256d, _mm_load_pd, _mm256_add_pd, _mm256_fmadd_pd, _mm256_fmsub_pd, _mm256_loadu_pd, _mm256_mul_pd,
_mm256_permute2f128_pd, _mm256_set_m128d, _mm256_storeu_pd, _mm256_sub_pd, _mm256_unpackhi_pd, _mm256_unpacklo_pd,
};
use crate::cpu_fft64_avx::reim::{as_arr, as_arr_mut};
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2,fma")]
pub(crate) fn ifft_avx2_fma(m: usize, omg: &[f64], data: &mut [f64]) {
if m < 16 {
use poulpy_hal::reference::fft64::reim::ifft_ref;
ifft_ref(m, omg, data);
return;
}
assert!(data.len() == 2 * m);
let (re, im) = data.split_at_mut(m);
if m == 16 {
ifft16_avx2_fma(
as_arr_mut::<16, f64>(re),
as_arr_mut::<16, f64>(im),
as_arr::<16, f64>(omg),
)
} else if m <= 2048 {
ifft_bfs_16_avx2_fma(m, re, im, omg, 0);
} else {
ifft_rec_16_avx2_fma(m, re, im, omg, 0);
}
}
unsafe extern "sysv64" {
unsafe fn ifft16_avx2_fma_asm(re: *mut f64, im: *mut f64, omg: *const f64);
}
#[cfg(any(target_arch = "x86_64", target_arch = "x86"))]
#[target_feature(enable = "avx2,fma")]
fn ifft16_avx2_fma(re: &mut [f64; 16], im: &mut [f64; 16], omg: &[f64; 16]) {
unsafe {
ifft16_avx2_fma_asm(re.as_mut_ptr(), im.as_mut_ptr(), omg.as_ptr());
}
}
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2,fma")]
fn ifft_rec_16_avx2_fma(m: usize, re: &mut [f64], im: &mut [f64], omg: &[f64], mut pos: usize) -> usize {
if m <= 2048 {
return ifft_bfs_16_avx2_fma(m, re, im, omg, pos);
};
let h: usize = m >> 1;
pos = ifft_rec_16_avx2_fma(h, re, im, omg, pos);
pos = ifft_rec_16_avx2_fma(h, &mut re[h..], &mut im[h..], omg, pos);
inv_twiddle_ifft_avx2_fma(h, re, im, *as_arr::<2, f64>(&omg[pos..]));
pos += 2;
pos
}
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2,fma")]
fn ifft_bfs_16_avx2_fma(m: usize, re: &mut [f64], im: &mut [f64], omg: &[f64], mut pos: usize) -> usize {
let log_m: usize = (usize::BITS - (m - 1).leading_zeros()) as usize;
for off in (0..m).step_by(16) {
ifft16_avx2_fma(
as_arr_mut::<16, f64>(&mut re[off..]),
as_arr_mut::<16, f64>(&mut im[off..]),
as_arr::<16, f64>(&omg[pos..]),
);
pos += 16;
}
let mut h: usize = 16;
let m_half: usize = m >> 1;
while h < m_half {
let mm: usize = h << 2;
for off in (0..m).step_by(mm) {
inv_bitwiddle_ifft_avx2_fma(
h,
&mut re[off..],
&mut im[off..],
as_arr::<4, f64>(&omg[pos..]),
);
pos += 4;
}
h = mm;
}
if !log_m.is_multiple_of(2) {
inv_twiddle_ifft_avx2_fma(h, re, im, *as_arr::<2, f64>(&omg[pos..]));
pos += 2;
}
pos
}
#[cfg(any(target_arch = "x86_64", target_arch = "x86"))]
#[target_feature(enable = "avx2,fma")]
fn inv_twiddle_ifft_avx2_fma(h: usize, re: &mut [f64], im: &mut [f64], omg: [f64; 2]) {
unsafe {
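// Inverse butterfly over pairs (x[j], x[j + h]): x[j] <- x[j] + x[j + h] and
// x[j + h] <- om * (x[j] - x[j + h]), with the complex product formed by FMA.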
let omx: __m128d = _mm_load_pd(omg.as_ptr());
let omra: __m256d = _mm256_set_m128d(omx, omx);
let omi: __m256d = _mm256_unpackhi_pd(omra, omra);
let omr: __m256d = _mm256_unpacklo_pd(omra, omra);
let mut r0: *mut f64 = re.as_mut_ptr();
let mut r1: *mut f64 = re.as_mut_ptr().add(h);
let mut i0: *mut f64 = im.as_mut_ptr();
let mut i1: *mut f64 = im.as_mut_ptr().add(h);
for _ in (0..h).step_by(4) {
let mut ur0: __m256d = _mm256_loadu_pd(r0);
let mut ur1: __m256d = _mm256_loadu_pd(r1);
let mut ui0: __m256d = _mm256_loadu_pd(i0);
let mut ui1: __m256d = _mm256_loadu_pd(i1);
let tra = _mm256_sub_pd(ur0, ur1);
let tia = _mm256_sub_pd(ui0, ui1);
ur0 = _mm256_add_pd(ur0, ur1);
ui0 = _mm256_add_pd(ui0, ui1);
ur1 = _mm256_mul_pd(omi, tia);
ui1 = _mm256_mul_pd(omi, tra);
ur1 = _mm256_fmsub_pd(omr, tra, ur1);
ui1 = _mm256_fmadd_pd(omr, tia, ui1);
_mm256_storeu_pd(r0, ur0);
_mm256_storeu_pd(r1, ur1);
_mm256_storeu_pd(i0, ui0);
_mm256_storeu_pd(i1, ui1);
r0 = r0.add(4);
r1 = r1.add(4);
i0 = i0.add(4);
i1 = i1.add(4);
}
}
}
#[cfg(any(target_arch = "x86_64", target_arch = "x86"))]
#[target_feature(enable = "avx2,fma")]
fn inv_bitwiddle_ifft_avx2_fma(h: usize, re: &mut [f64], im: &mut [f64], omg: &[f64; 4]) {
unsafe {
let mut r0: *mut f64 = re.as_mut_ptr();
let mut r1: *mut f64 = re.as_mut_ptr().add(h);
let mut r2: *mut f64 = re.as_mut_ptr().add(2 * h);
let mut r3: *mut f64 = re.as_mut_ptr().add(3 * h);
let mut i0: *mut f64 = im.as_mut_ptr();
let mut i1: *mut f64 = im.as_mut_ptr().add(h);
let mut i2: *mut f64 = im.as_mut_ptr().add(2 * h);
let mut i3: *mut f64 = im.as_mut_ptr().add(3 * h);
let om0: __m256d = _mm256_loadu_pd(omg.as_ptr());
let omb: __m256d = _mm256_permute2f128_pd(om0, om0, 0x11);
let oma: __m256d = _mm256_permute2f128_pd(om0, om0, 0x00);
let omai: __m256d = _mm256_unpackhi_pd(oma, oma);
let omar: __m256d = _mm256_unpacklo_pd(oma, oma);
let ombi: __m256d = _mm256_unpackhi_pd(omb, omb);
let ombr: __m256d = _mm256_unpacklo_pd(omb, omb);
for _ in (0..h).step_by(4) {
let mut ur0: __m256d = _mm256_loadu_pd(r0);
let mut ur1: __m256d = _mm256_loadu_pd(r1);
let mut ur2: __m256d = _mm256_loadu_pd(r2);
let mut ur3: __m256d = _mm256_loadu_pd(r3);
let mut ui0: __m256d = _mm256_loadu_pd(i0);
let mut ui1: __m256d = _mm256_loadu_pd(i1);
let mut ui2: __m256d = _mm256_loadu_pd(i2);
let mut ui3: __m256d = _mm256_loadu_pd(i3);
let mut tra: __m256d = _mm256_sub_pd(ur0, ur1);
let mut trb: __m256d = _mm256_sub_pd(ur2, ur3);
let mut tia: __m256d = _mm256_sub_pd(ui0, ui1);
let mut tib: __m256d = _mm256_sub_pd(ui2, ui3);
ur0 = _mm256_add_pd(ur0, ur1);
ur2 = _mm256_add_pd(ur2, ur3);
ui0 = _mm256_add_pd(ui0, ui1);
ui2 = _mm256_add_pd(ui2, ui3);
ur1 = _mm256_mul_pd(omai, tia);
ur3 = _mm256_mul_pd(omar, tib);
ui1 = _mm256_mul_pd(omai, tra);
ui3 = _mm256_mul_pd(omar, trb);
ur1 = _mm256_fmsub_pd(omar, tra, ur1);
ur3 = _mm256_fmadd_pd(omai, trb, ur3);
ui1 = _mm256_fmadd_pd(omar, tia, ui1);
ui3 = _mm256_fmsub_pd(omai, tib, ui3);
tra = _mm256_sub_pd(ur0, ur2);
trb = _mm256_sub_pd(ur1, ur3);
tia = _mm256_sub_pd(ui0, ui2);
tib = _mm256_sub_pd(ui1, ui3);
ur0 = _mm256_add_pd(ur0, ur2);
ur1 = _mm256_add_pd(ur1, ur3);
ui0 = _mm256_add_pd(ui0, ui2);
ui1 = _mm256_add_pd(ui1, ui3);
ur2 = _mm256_mul_pd(ombi, tia);
ur3 = _mm256_mul_pd(ombi, tib);
ui2 = _mm256_mul_pd(ombi, tra);
ui3 = _mm256_mul_pd(ombi, trb);
ur2 = _mm256_fmsub_pd(ombr, tra, ur2);
ur3 = _mm256_fmsub_pd(ombr, trb, ur3);
ui2 = _mm256_fmadd_pd(ombr, tia, ui2);
ui3 = _mm256_fmadd_pd(ombr, tib, ui3);
_mm256_storeu_pd(r0, ur0);
_mm256_storeu_pd(r1, ur1);
_mm256_storeu_pd(r2, ur2);
_mm256_storeu_pd(r3, ur3);
_mm256_storeu_pd(i0, ui0);
_mm256_storeu_pd(i1, ui1);
_mm256_storeu_pd(i2, ui2);
_mm256_storeu_pd(i3, ui3);
r0 = r0.add(4);
r1 = r1.add(4);
r2 = r2.add(4);
r3 = r3.add(4);
i0 = i0.add(4);
i1 = i1.add(4);
i2 = i2.add(4);
i3 = i3.add(4);
}
}
}
#[test]
fn test_ifft_avx2_fma() {
use super::*;
#[cfg(any(target_arch = "x86_64", target_arch = "x86"))]
#[target_feature(enable = "avx2,fma")]
fn internal(log_m: usize) {
use poulpy_hal::reference::fft64::reim::ReimIFFTRef;
let m: usize = 1 << log_m;
let table: ReimIFFTTable<f64> = ReimIFFTTable::<f64>::new(m);
let mut values_0: Vec<f64> = vec![0f64; m << 1];
let scale: f64 = 1.0f64 / m as f64;
values_0
.iter_mut()
.enumerate()
.for_each(|(i, x)| *x = (i + 1) as f64 * scale);
let mut values_1: Vec<f64> = values_0.clone();
ReimIFFTAvx::reim_dft_execute(&table, &mut values_0);
ReimIFFTRef::reim_dft_execute(&table, &mut values_1);
let max_diff: f64 = 1.0 / ((1u64 << (53 - log_m - 1)) as f64);
for i in 0..m * 2 {
let diff: f64 = (values_0[i] - values_1[i]).abs();
assert!(
diff <= max_diff,
"{} -> {}-{} = {}",
i,
values_0[i],
values_1[i],
diff
)
}
}
if std::is_x86_feature_detected!("avx2") {
for log_m in 0..16 {
unsafe { internal(log_m) }
}
} else {
eprintln!("skipping: CPU lacks avx2");
}
}

View File

@@ -0,0 +1,72 @@
// ----------------------------------------------------------------------
// DISCLAIMER
//
// This module contains code that has been directly ported from the
// spqlios-arithmetic library
// (https://github.com/tfhe/spqlios-arithmetic), which is licensed
// under the Apache License, Version 2.0.
//
// The porting process from C to Rust was done with minimal changes
// in order to preserve the semantics and performance characteristics
// of the original implementation.
//
// Both Poulpy and spqlios-arithmetic are distributed under the terms
// of the Apache License, Version 2.0. See the LICENSE file for details.
//
// ----------------------------------------------------------------------
#![allow(bad_asm_style)]
mod conversion;
mod fft_avx2_fma;
mod fft_vec_avx2_fma;
mod ifft_avx2_fma;
use std::arch::global_asm;
pub(crate) use conversion::*;
pub(crate) use fft_vec_avx2_fma::*;
use poulpy_hal::reference::fft64::reim::{ReimDFTExecute, ReimFFTTable, ReimIFFTTable};
use rand_distr::num_traits::{Float, FloatConst};
use crate::cpu_fft64_avx::reim::{fft_avx2_fma::fft_avx2_fma, ifft_avx2_fma::ifft_avx2_fma};
global_asm!(
include_str!("fft16_avx2_fma.s"),
include_str!("ifft16_avx2_fma.s")
);
#[inline(always)]
pub(crate) fn as_arr<const SIZE: usize, R: Float + FloatConst>(x: &[R]) -> &[R; SIZE] {
debug_assert!(x.len() >= SIZE);
unsafe { &*(x.as_ptr() as *const [R; SIZE]) }
}
#[inline(always)]
pub(crate) fn as_arr_mut<const SIZE: usize, R: Float + FloatConst>(x: &mut [R]) -> &mut [R; SIZE] {
debug_assert!(x.len() >= SIZE);
unsafe { &mut *(x.as_mut_ptr() as *mut [R; SIZE]) }
}
pub struct ReimFFTAvx;
impl ReimDFTExecute<ReimFFTTable<f64>, f64> for ReimFFTAvx {
#[inline(always)]
fn reim_dft_execute(table: &ReimFFTTable<f64>, data: &mut [f64]) {
unsafe {
fft_avx2_fma(table.m(), table.omg(), data);
}
}
}
pub struct ReimIFFTAvx;
impl ReimDFTExecute<ReimIFFTTable<f64>, f64> for ReimIFFTAvx {
#[inline(always)]
fn reim_dft_execute(table: &ReimIFFTTable<f64>, data: &mut [f64]) {
unsafe {
ifft_avx2_fma(table.m(), table.omg(), data);
}
}
}
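// Illustrative usage sketch (not part of this commit): both wrappers are driven
// through the `ReimDFTExecute` trait after confirming CPU support at runtime,
// e.g. (same pattern as the tests in fft_avx2_fma.rs / ifft_avx2_fma.rs):
//
// let m = 1024;
// let table = ReimFFTTable::<f64>::new(m);
// let mut data = vec![0f64; 2 * m]; // [re..., im...] layout
// if std::is_x86_feature_detected!("avx2") && std::is_x86_feature_detected!("fma") {
// ReimFFTAvx::reim_dft_execute(&table, &mut data);
// }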