Ref. + AVX code & generic tests + benches (#85)
commit 56dbd29c59 (parent 99b9e3e10e), committed via GitHub
poulpy-backend/src/cpu_fft64_avx/reim/conversion.rs (new file, 271 lines)
@@ -0,0 +1,271 @@
|
||||
/// # Correctness
|
||||
/// Ensured for inputs with absolute value bounded by 2^50-1
|
||||
/// # Safety
|
||||
/// Caller must ensure the CPU supports AVX2 and FMA (e.g., via `is_x86_feature_detected!("avx2")` and `is_x86_feature_detected!("fma")`).
|
||||
#[cfg(any(target_arch = "x86_64", target_arch = "x86"))]
|
||||
#[target_feature(enable = "fma")]
|
||||
pub fn reim_from_znx_i64_bnd50_fma(res: &mut [f64], a: &[i64]) {
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
assert_eq!(res.len(), a.len())
|
||||
}
|
||||
|
||||
let n: usize = res.len();
|
||||
|
||||
unsafe {
|
||||
use std::arch::x86_64::{
|
||||
__m256d, __m256i, _mm256_add_epi64, _mm256_castsi256_pd, _mm256_loadu_si256, _mm256_or_pd, _mm256_set1_epi64x,
|
||||
_mm256_set1_pd, _mm256_storeu_pd, _mm256_sub_pd,
|
||||
};
|
||||
|
||||
let expo: f64 = (1i64 << 52) as f64;
|
||||
let add_cst: i64 = 1i64 << 51;
|
||||
let sub_cst: f64 = (3i64 << 51) as f64;
|
||||
|
||||
let expo_256: __m256d = _mm256_set1_pd(expo);
|
||||
let add_cst_256: __m256i = _mm256_set1_epi64x(add_cst);
|
||||
let sub_cst_256: __m256d = _mm256_set1_pd(sub_cst);
|
||||
|
||||
let mut res_ptr: *mut f64 = res.as_mut_ptr();
|
||||
let mut a_ptr: *const __m256i = a.as_ptr() as *const __m256i;
|
||||
|
||||
let span: usize = n >> 2;
|
||||
|
||||
for _ in 0..span {
|
||||
let mut ai64_256: __m256i = _mm256_loadu_si256(a_ptr);
|
||||
|
||||
ai64_256 = _mm256_add_epi64(ai64_256, add_cst_256);
|
||||
|
||||
let mut af64_256: __m256d = _mm256_castsi256_pd(ai64_256);
|
||||
af64_256 = _mm256_or_pd(af64_256, expo_256);
|
||||
af64_256 = _mm256_sub_pd(af64_256, sub_cst_256);
|
||||
|
||||
_mm256_storeu_pd(res_ptr, af64_256);
|
||||
|
||||
res_ptr = res_ptr.add(4);
|
||||
a_ptr = a_ptr.add(1);
|
||||
}
|
||||
|
||||
if !res.len().is_multiple_of(4) {
|
||||
use poulpy_hal::reference::fft64::reim::reim_from_znx_i64_ref;
|
||||
reim_from_znx_i64_ref(&mut res[span << 2..], &a[span << 2..])
|
||||
}
|
||||
}
|
||||
}
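For reference, a scalar model of the bit trick above (a hypothetical helper, not part of the crate): adding 2^51 shifts the input into [0, 2^52), OR-ing in the exponent bits of 2^52 reinterprets those bits as the double 2^52 + x + 2^51, and subtracting 3*2^51 leaves exactly `x as f64` for |x| < 2^51.
// Hypothetical scalar equivalent of one lane of reim_from_znx_i64_bnd50_fma:
fn i64_to_f64_bnd50(x: i64) -> f64 {
    debug_assert!(x.unsigned_abs() < (1u64 << 51));
    let biased = (x + (1i64 << 51)) as u64;                 // now in [0, 2^52)
    let f = f64::from_bits(biased | 0x4330_0000_0000_0000); // 0x433... = bits of 2^52
    f - (3i64 << 51) as f64                                 // remove the 2^52 + 2^51 offset
}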
|
||||
|
||||
/// # Correctness
|
||||
/// Only ensured for inputs with absolute value bounded by 2^63-1
|
||||
/// # Safety
|
||||
/// Caller must ensure the CPU supports AVX2 and FMA (e.g., via `is_x86_feature_detected!("avx2")` and `is_x86_feature_detected!("fma")`).
|
||||
#[allow(dead_code)]
|
||||
#[cfg(any(target_arch = "x86_64", target_arch = "x86"))]
|
||||
#[target_feature(enable = "avx2,fma")]
|
||||
pub fn reim_to_znx_i64_bnd63_avx2_fma(res: &mut [i64], divisor: f64, a: &[f64]) {
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
assert_eq!(res.len(), a.len())
|
||||
}
|
||||
|
||||
let sign_mask: u64 = 0x8000000000000000u64;
|
||||
let expo_mask: u64 = 0x7FF0000000000000u64;
|
||||
let mantissa_mask: u64 = (i64::MAX as u64) ^ expo_mask;
|
||||
let mantissa_msb: u64 = 0x0010000000000000u64;
|
||||
let divi_bits: f64 = divisor * (1i64 << 52) as f64;
|
||||
let offset: f64 = divisor / 2.;
|
||||
|
||||
unsafe {
|
||||
use std::arch::x86_64::{
|
||||
__m256d, __m256i, _mm256_add_pd, _mm256_and_pd, _mm256_and_si256, _mm256_castpd_si256, _mm256_castsi256_pd,
|
||||
_mm256_loadu_pd, _mm256_or_pd, _mm256_or_si256, _mm256_set1_epi64x, _mm256_set1_pd, _mm256_sllv_epi64,
|
||||
_mm256_srli_epi64, _mm256_srlv_epi64, _mm256_sub_epi64, _mm256_xor_si256,
|
||||
};
|
||||
|
||||
let sign_mask_256: __m256d = _mm256_castsi256_pd(_mm256_set1_epi64x(sign_mask as i64));
|
||||
let expo_mask_256: __m256i = _mm256_set1_epi64x(expo_mask as i64);
|
||||
let mantissa_mask_256: __m256i = _mm256_set1_epi64x(mantissa_mask as i64);
|
||||
let mantissa_msb_256: __m256i = _mm256_set1_epi64x(mantissa_msb as i64);
|
||||
let offset_256 = _mm256_set1_pd(offset);
|
||||
let divi_bits_256 = _mm256_castpd_si256(_mm256_set1_pd(divi_bits));
|
||||
|
||||
let mut res_ptr: *mut __m256i = res.as_mut_ptr() as *mut __m256i;
|
||||
let mut a_ptr: *const f64 = a.as_ptr();
|
||||
|
||||
let span: usize = res.len() >> 2;
|
||||
|
||||
for _ in 0..span {
|
||||
// read the next value
|
||||
use std::arch::x86_64::_mm256_storeu_si256;
|
||||
let mut a: __m256d = _mm256_loadu_pd(a_ptr);
|
||||
|
||||
// a += sign(a) * m/2
|
||||
let asign: __m256d = _mm256_and_pd(a, sign_mask_256);
|
||||
a = _mm256_add_pd(a, _mm256_or_pd(asign, offset_256));
|
||||
|
||||
// sign: either 0 or -1
|
||||
let mut sign_mask: __m256i = _mm256_castpd_si256(asign);
|
||||
sign_mask = _mm256_sub_epi64(_mm256_set1_epi64x(0), _mm256_srli_epi64(sign_mask, 63));
|
||||
|
||||
// compute the exponents
|
||||
let a0exp: __m256i = _mm256_and_si256(_mm256_castpd_si256(a), expo_mask_256);
|
||||
let mut a0lsh: __m256i = _mm256_sub_epi64(a0exp, divi_bits_256);
|
||||
let mut a0rsh: __m256i = _mm256_sub_epi64(divi_bits_256, a0exp);
|
||||
a0lsh = _mm256_srli_epi64(a0lsh, 52);
|
||||
a0rsh = _mm256_srli_epi64(a0rsh, 52);
|
||||
|
||||
// compute the new mantissa
|
||||
let mut a0pos: __m256i = _mm256_and_si256(_mm256_castpd_si256(a), mantissa_mask_256);
|
||||
a0pos = _mm256_or_si256(a0pos, mantissa_msb_256);
|
||||
a0lsh = _mm256_sllv_epi64(a0pos, a0lsh);
|
||||
a0rsh = _mm256_srlv_epi64(a0pos, a0rsh);
|
||||
let mut out: __m256i = _mm256_or_si256(a0lsh, a0rsh);
|
||||
|
||||
// negate if the sign was negative
|
||||
out = _mm256_xor_si256(out, sign_mask);
|
||||
out = _mm256_sub_epi64(out, sign_mask);
|
||||
|
||||
// stores
|
||||
_mm256_storeu_si256(res_ptr, out);
|
||||
|
||||
res_ptr = res_ptr.add(1);
|
||||
a_ptr = a_ptr.add(4);
|
||||
}
|
||||
|
||||
if !res.len().is_multiple_of(4) {
|
||||
use poulpy_hal::reference::fft64::reim::reim_to_znx_i64_ref;
|
||||
reim_to_znx_i64_ref(&mut res[span << 2..], divisor, &a[span << 2..])
|
||||
}
|
||||
}
|
||||
}
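An illustrative scalar model of what this kernel computes (an inference from the sign offset and exponent arithmetic above, not code taken from the crate): the nearest integer to a/divisor, with ties rounded away from zero as in `f64::round`; the exponent-only shifts also suggest that `divisor` is expected to be a power of two.
// Hypothetical scalar model, one lane, assuming a positive power-of-two divisor:
fn to_znx_i64_bnd63(x: f64, divisor: f64) -> i64 {
    (x / divisor).round() as i64 // the sign(a) * divisor/2 offset implements round-half-away-from-zero
}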
|
||||
|
||||
/// # Correctness
|
||||
/// Only ensured for inputs with absolute value bounded by 2^63-1
|
||||
/// # Safety
|
||||
/// Caller must ensure the CPU supports AVX2 and FMA (e.g., via `is_x86_feature_detected!("avx2")` and `is_x86_feature_detected!("fma")`).
|
||||
#[cfg(any(target_arch = "x86_64", target_arch = "x86"))]
|
||||
#[target_feature(enable = "avx2,fma")]
|
||||
pub fn reim_to_znx_i64_inplace_bnd63_avx2_fma(res: &mut [f64], divisor: f64) {
|
||||
let sign_mask: u64 = 0x8000000000000000u64;
|
||||
let expo_mask: u64 = 0x7FF0000000000000u64;
|
||||
let mantissa_mask: u64 = (i64::MAX as u64) ^ expo_mask;
|
||||
let mantissa_msb: u64 = 0x0010000000000000u64;
|
||||
let divi_bits: f64 = divisor * (1i64 << 52) as f64;
|
||||
let offset: f64 = divisor / 2.;
|
||||
|
||||
unsafe {
|
||||
use std::arch::x86_64::{
|
||||
__m256d, __m256i, _mm256_add_pd, _mm256_and_pd, _mm256_and_si256, _mm256_castpd_si256, _mm256_castsi256_pd,
|
||||
_mm256_loadu_pd, _mm256_or_pd, _mm256_or_si256, _mm256_set1_epi64x, _mm256_set1_pd, _mm256_sllv_epi64,
|
||||
_mm256_srli_epi64, _mm256_srlv_epi64, _mm256_sub_epi64, _mm256_xor_si256,
|
||||
};
|
||||
|
||||
use poulpy_hal::reference::fft64::reim::reim_to_znx_i64_inplace_ref;
|
||||
|
||||
let sign_mask_256: __m256d = _mm256_castsi256_pd(_mm256_set1_epi64x(sign_mask as i64));
|
||||
let expo_mask_256: __m256i = _mm256_set1_epi64x(expo_mask as i64);
|
||||
let mantissa_mask_256: __m256i = _mm256_set1_epi64x(mantissa_mask as i64);
|
||||
let mantissa_msb_256: __m256i = _mm256_set1_epi64x(mantissa_msb as i64);
|
||||
let offset_256: __m256d = _mm256_set1_pd(offset);
|
||||
let divi_bits_256: __m256i = _mm256_castpd_si256(_mm256_set1_pd(divi_bits));
|
||||
|
||||
let mut res_ptr_4xi64: *mut __m256i = res.as_mut_ptr() as *mut __m256i;
|
||||
let mut res_ptr_1xf64: *mut f64 = res.as_mut_ptr();
|
||||
|
||||
let span: usize = res.len() >> 2;
|
||||
|
||||
for _ in 0..span {
|
||||
// read the next value
|
||||
use std::arch::x86_64::_mm256_storeu_si256;
|
||||
let mut a: __m256d = _mm256_loadu_pd(res_ptr_1xf64);
|
||||
|
||||
// a += sign(a) * m/2
|
||||
let asign: __m256d = _mm256_and_pd(a, sign_mask_256);
|
||||
a = _mm256_add_pd(a, _mm256_or_pd(asign, offset_256));
|
||||
|
||||
// sign: either 0 or -1
|
||||
let mut sign_mask: __m256i = _mm256_castpd_si256(asign);
|
||||
sign_mask = _mm256_sub_epi64(_mm256_set1_epi64x(0), _mm256_srli_epi64(sign_mask, 63));
|
||||
|
||||
// compute the exponents
|
||||
let a0exp: __m256i = _mm256_and_si256(_mm256_castpd_si256(a), expo_mask_256);
|
||||
let mut a0lsh: __m256i = _mm256_sub_epi64(a0exp, divi_bits_256);
|
||||
let mut a0rsh: __m256i = _mm256_sub_epi64(divi_bits_256, a0exp);
|
||||
a0lsh = _mm256_srli_epi64(a0lsh, 52);
|
||||
a0rsh = _mm256_srli_epi64(a0rsh, 52);
|
||||
|
||||
// compute the new mantissa
|
||||
let mut a0pos: __m256i = _mm256_and_si256(_mm256_castpd_si256(a), mantissa_mask_256);
|
||||
a0pos = _mm256_or_si256(a0pos, mantissa_msb_256);
|
||||
a0lsh = _mm256_sllv_epi64(a0pos, a0lsh);
|
||||
a0rsh = _mm256_srlv_epi64(a0pos, a0rsh);
|
||||
let mut out: __m256i = _mm256_or_si256(a0lsh, a0rsh);
|
||||
|
||||
// negate if the sign was negative
|
||||
out = _mm256_xor_si256(out, sign_mask);
|
||||
out = _mm256_sub_epi64(out, sign_mask);
|
||||
|
||||
// stores
|
||||
_mm256_storeu_si256(res_ptr_4xi64, out);
|
||||
|
||||
res_ptr_4xi64 = res_ptr_4xi64.add(1);
|
||||
res_ptr_1xf64 = res_ptr_1xf64.add(4);
|
||||
}
|
||||
|
||||
if !res.len().is_multiple_of(4) {
|
||||
reim_to_znx_i64_inplace_ref(&mut res[span << 2..], divisor)
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
/// # Correctness
|
||||
/// Only ensured for inputs with absolute value bounded by 2^50-1
|
||||
/// # Safety
|
||||
/// Caller must ensure the CPU supports AVX2 and FMA (e.g., via `is_x86_feature_detected!("avx2")` and `is_x86_feature_detected!("fma")`).
|
||||
#[cfg(any(target_arch = "x86_64", target_arch = "x86"))]
|
||||
#[target_feature(enable = "fma")]
|
||||
#[allow(dead_code)]
|
||||
pub fn reim_to_znx_i64_avx2_bnd50_fma(res: &mut [i64], divisor: f64, a: &[f64]) {
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
assert_eq!(res.len(), a.len())
|
||||
}
|
||||
|
||||
unsafe {
|
||||
use std::arch::x86_64::{
|
||||
__m256d, __m256i, _mm256_add_pd, _mm256_and_si256, _mm256_castpd_si256, _mm256_loadu_pd, _mm256_set1_epi64x,
|
||||
_mm256_set1_pd, _mm256_storeu_si256, _mm256_sub_epi64,
|
||||
};
|
||||
|
||||
let mantissa_mask: u64 = 0x000FFFFFFFFFFFFFu64;
|
||||
let sub_cst: i64 = 1i64 << 51;
|
||||
let add_cst: f64 = divisor * (3i64 << 51) as f64;
|
||||
|
||||
let sub_cst_4: __m256i = _mm256_set1_epi64x(sub_cst);
|
||||
let add_cst_4: std::arch::x86_64::__m256d = _mm256_set1_pd(add_cst);
|
||||
let mantissa_mask_4: __m256i = _mm256_set1_epi64x(mantissa_mask as i64);
|
||||
|
||||
let mut res_ptr: *mut __m256i = res.as_mut_ptr() as *mut __m256i;
|
||||
let mut a_ptr = a.as_ptr();
|
||||
|
||||
let span: usize = res.len() >> 2;
|
||||
|
||||
for _ in 0..span {
|
||||
// read the next value
|
||||
let mut a: __m256d = _mm256_loadu_pd(a_ptr);
|
||||
a = _mm256_add_pd(a, add_cst_4);
|
||||
let mut ai: __m256i = _mm256_castpd_si256(a);
|
||||
ai = _mm256_and_si256(ai, mantissa_mask_4);
|
||||
ai = _mm256_sub_epi64(ai, sub_cst_4);
|
||||
// store the next value
|
||||
_mm256_storeu_si256(res_ptr, ai);
|
||||
|
||||
res_ptr = res_ptr.add(1);
|
||||
a_ptr = a_ptr.add(4);
|
||||
}
|
||||
|
||||
if !res.len().is_multiple_of(4) {
|
||||
use poulpy_hal::reference::fft64::reim::reim_to_znx_i64_ref;
|
||||
reim_to_znx_i64_ref(&mut res[span << 2..], divisor, &a[span << 2..])
|
||||
}
|
||||
}
|
||||
}
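The bnd50 variant uses the magic-constant trick instead: adding `divisor * (3 << 51)` pins the exponent of the sum so that its low 52 mantissa bits hold round(a/divisor) + 2^51, which the mask and subtraction recover. A hedged scalar sketch, under the same power-of-two-divisor assumption as above:
// Hypothetical scalar equivalent of one lane of reim_to_znx_i64_avx2_bnd50_fma:
fn to_znx_i64_bnd50(x: f64, divisor: f64) -> i64 {
    let shifted = x + divisor * (3i64 << 51) as f64;          // exponent fixed by the constant
    let mantissa = shifted.to_bits() & 0x000F_FFFF_FFFF_FFFF; // low 52 bits = round(x/divisor) + 2^51
    mantissa as i64 - (1i64 << 51)                            // re-center around zero
}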
|
||||
poulpy-backend/src/cpu_fft64_avx/reim/fft16_avx2_fma.s (new file, 162 lines)
@@ -0,0 +1,162 @@
|
||||
# ----------------------------------------------------------------------
|
||||
# This kernel is a direct port of the FFT16 routine from spqlios-arithmetic
|
||||
# (https://github.com/tfhe/spqlios-arithmetic)
|
||||
# ----------------------------------------------------------------------
|
||||
#
|
||||
|
||||
.text
|
||||
.globl fft16_avx2_fma_asm
|
||||
.hidden fft16_avx2_fma_asm
|
||||
.p2align 4, 0x90
|
||||
.type fft16_avx2_fma_asm,@function
|
||||
fft16_avx2_fma_asm:
|
||||
.att_syntax prefix
|
||||
|
||||
# SysV args: %rdi = re*, %rsi = im*, %rdx = omg*
|
||||
# stage 0: load inputs
|
||||
vmovupd (%rdi),%ymm0 # ra0
|
||||
vmovupd 0x20(%rdi),%ymm1 # ra4
|
||||
vmovupd 0x40(%rdi),%ymm2 # ra8
|
||||
vmovupd 0x60(%rdi),%ymm3 # ra12
|
||||
vmovupd (%rsi),%ymm4 # ia0
|
||||
vmovupd 0x20(%rsi),%ymm5 # ia4
|
||||
vmovupd 0x40(%rsi),%ymm6 # ia8
|
||||
vmovupd 0x60(%rsi),%ymm7 # ia12
|
||||
|
||||
# stage 1
|
||||
vmovupd (%rdx),%xmm12
|
||||
vinsertf128 $1, %xmm12, %ymm12, %ymm12 # omriri
|
||||
vshufpd $15, %ymm12, %ymm12, %ymm13 # omai
|
||||
vshufpd $0, %ymm12, %ymm12, %ymm12 # omar
|
||||
vmulpd %ymm6,%ymm13,%ymm8
|
||||
vmulpd %ymm7,%ymm13,%ymm9
|
||||
vmulpd %ymm2,%ymm13,%ymm10
|
||||
vmulpd %ymm3,%ymm13,%ymm11
|
||||
vfmsub231pd %ymm2,%ymm12,%ymm8
|
||||
vfmsub231pd %ymm3,%ymm12,%ymm9
|
||||
vfmadd231pd %ymm6,%ymm12,%ymm10
|
||||
vfmadd231pd %ymm7,%ymm12,%ymm11
|
||||
vsubpd %ymm8,%ymm0,%ymm2
|
||||
vsubpd %ymm9,%ymm1,%ymm3
|
||||
vsubpd %ymm10,%ymm4,%ymm6
|
||||
vsubpd %ymm11,%ymm5,%ymm7
|
||||
vaddpd %ymm8,%ymm0,%ymm0
|
||||
vaddpd %ymm9,%ymm1,%ymm1
|
||||
vaddpd %ymm10,%ymm4,%ymm4
|
||||
vaddpd %ymm11,%ymm5,%ymm5
|
||||
|
||||
# stage 2
|
||||
vmovupd 16(%rdx),%xmm12
|
||||
vinsertf128 $1, %xmm12, %ymm12, %ymm12 # omriri
|
||||
vshufpd $15, %ymm12, %ymm12, %ymm13 # omai
|
||||
vshufpd $0, %ymm12, %ymm12, %ymm12 # omar
|
||||
vmulpd %ymm5,%ymm13,%ymm8
|
||||
vmulpd %ymm7,%ymm12,%ymm9
|
||||
vmulpd %ymm1,%ymm13,%ymm10
|
||||
vmulpd %ymm3,%ymm12,%ymm11
|
||||
vfmsub231pd %ymm1,%ymm12,%ymm8
|
||||
vfmadd231pd %ymm3,%ymm13,%ymm9
|
||||
vfmadd231pd %ymm5,%ymm12,%ymm10
|
||||
vfmsub231pd %ymm7,%ymm13,%ymm11
|
||||
vsubpd %ymm8,%ymm0,%ymm1
|
||||
vaddpd %ymm9,%ymm2,%ymm3
|
||||
vsubpd %ymm10,%ymm4,%ymm5
|
||||
vaddpd %ymm11,%ymm6,%ymm7
|
||||
vaddpd %ymm8,%ymm0,%ymm0
|
||||
vsubpd %ymm9,%ymm2,%ymm2
|
||||
vaddpd %ymm10,%ymm4,%ymm4
|
||||
vsubpd %ymm11,%ymm6,%ymm6
|
||||
|
||||
# stage 3
|
||||
vmovupd 0x20(%rdx),%ymm12
|
||||
vshufpd $15, %ymm12, %ymm12, %ymm13 # omai
|
||||
vshufpd $0, %ymm12, %ymm12, %ymm12 # omar
|
||||
|
||||
vperm2f128 $0x31,%ymm2,%ymm0,%ymm8
|
||||
vperm2f128 $0x31,%ymm3,%ymm1,%ymm9
|
||||
vperm2f128 $0x31,%ymm6,%ymm4,%ymm10
|
||||
vperm2f128 $0x31,%ymm7,%ymm5,%ymm11
|
||||
vperm2f128 $0x20,%ymm2,%ymm0,%ymm0
|
||||
vperm2f128 $0x20,%ymm3,%ymm1,%ymm1
|
||||
vperm2f128 $0x20,%ymm6,%ymm4,%ymm2
|
||||
vperm2f128 $0x20,%ymm7,%ymm5,%ymm3
|
||||
|
||||
vmulpd %ymm10,%ymm13,%ymm4
|
||||
vmulpd %ymm11,%ymm12,%ymm5
|
||||
vmulpd %ymm8,%ymm13,%ymm6
|
||||
vmulpd %ymm9,%ymm12,%ymm7
|
||||
vfmsub231pd %ymm8,%ymm12,%ymm4
|
||||
vfmadd231pd %ymm9,%ymm13,%ymm5
|
||||
vfmadd231pd %ymm10,%ymm12,%ymm6
|
||||
vfmsub231pd %ymm11,%ymm13,%ymm7
|
||||
vsubpd %ymm4,%ymm0,%ymm8
|
||||
vaddpd %ymm5,%ymm1,%ymm9
|
||||
vsubpd %ymm6,%ymm2,%ymm10
|
||||
vaddpd %ymm7,%ymm3,%ymm11
|
||||
vaddpd %ymm4,%ymm0,%ymm0
|
||||
vsubpd %ymm5,%ymm1,%ymm1
|
||||
vaddpd %ymm6,%ymm2,%ymm2
|
||||
vsubpd %ymm7,%ymm3,%ymm3
|
||||
|
||||
# stage 4
|
||||
vmovupd 0x40(%rdx),%ymm12
|
||||
vmovupd 0x60(%rdx),%ymm13
|
||||
|
||||
vunpckhpd %ymm1,%ymm0,%ymm4
|
||||
vunpckhpd %ymm3,%ymm2,%ymm6
|
||||
vunpckhpd %ymm9,%ymm8,%ymm5
|
||||
vunpckhpd %ymm11,%ymm10,%ymm7
|
||||
vunpcklpd %ymm1,%ymm0,%ymm0
|
||||
vunpcklpd %ymm3,%ymm2,%ymm2
|
||||
vunpcklpd %ymm9,%ymm8,%ymm1
|
||||
vunpcklpd %ymm11,%ymm10,%ymm3
|
||||
|
||||
vmulpd %ymm6,%ymm13,%ymm8
|
||||
vmulpd %ymm7,%ymm12,%ymm9
|
||||
vmulpd %ymm4,%ymm13,%ymm10
|
||||
vmulpd %ymm5,%ymm12,%ymm11
|
||||
vfmsub231pd %ymm4,%ymm12,%ymm8
|
||||
vfmadd231pd %ymm5,%ymm13,%ymm9
|
||||
vfmadd231pd %ymm6,%ymm12,%ymm10
|
||||
vfmsub231pd %ymm7,%ymm13,%ymm11
|
||||
vsubpd %ymm8,%ymm0,%ymm4
|
||||
vaddpd %ymm9,%ymm1,%ymm5
|
||||
vsubpd %ymm10,%ymm2,%ymm6
|
||||
vaddpd %ymm11,%ymm3,%ymm7
|
||||
vaddpd %ymm8,%ymm0,%ymm0
|
||||
vsubpd %ymm9,%ymm1,%ymm1
|
||||
vaddpd %ymm10,%ymm2,%ymm2
|
||||
vsubpd %ymm11,%ymm3,%ymm3
|
||||
|
||||
vunpckhpd %ymm7,%ymm3,%ymm11
|
||||
vunpckhpd %ymm5,%ymm1,%ymm9
|
||||
vunpcklpd %ymm7,%ymm3,%ymm10
|
||||
vunpcklpd %ymm5,%ymm1,%ymm8
|
||||
vunpckhpd %ymm6,%ymm2,%ymm3
|
||||
vunpckhpd %ymm4,%ymm0,%ymm1
|
||||
vunpcklpd %ymm6,%ymm2,%ymm2
|
||||
vunpcklpd %ymm4,%ymm0,%ymm0
|
||||
|
||||
vperm2f128 $0x31,%ymm10,%ymm2,%ymm6
|
||||
vperm2f128 $0x31,%ymm11,%ymm3,%ymm7
|
||||
vperm2f128 $0x20,%ymm10,%ymm2,%ymm4
|
||||
vperm2f128 $0x20,%ymm11,%ymm3,%ymm5
|
||||
vperm2f128 $0x31,%ymm8,%ymm0,%ymm2
|
||||
vperm2f128 $0x31,%ymm9,%ymm1,%ymm3
|
||||
vperm2f128 $0x20,%ymm8,%ymm0,%ymm0
|
||||
vperm2f128 $0x20,%ymm9,%ymm1,%ymm1
|
||||
|
||||
# stores
|
||||
vmovupd %ymm0,(%rdi) # ra0
|
||||
vmovupd %ymm1,0x20(%rdi) # ra4
|
||||
vmovupd %ymm2,0x40(%rdi) # ra8
|
||||
vmovupd %ymm3,0x60(%rdi) # ra12
|
||||
vmovupd %ymm4,(%rsi) # ia0
|
||||
vmovupd %ymm5,0x20(%rsi) # ia4
|
||||
vmovupd %ymm6,0x40(%rsi) # ia8
|
||||
vmovupd %ymm7,0x60(%rsi) # ia12
|
||||
vzeroupper
|
||||
ret
|
||||
|
||||
.size fft16_avx2_fma_asm, .-fft16_avx2_fma_asm
|
||||
.section .note.GNU-stack,"",@progbits
|
||||
poulpy-backend/src/cpu_fft64_avx/reim/fft_avx2_fma.rs (new file, 278 lines)
@@ -0,0 +1,278 @@
|
||||
use std::arch::x86_64::{
|
||||
__m128d, __m256d, _mm_load_pd, _mm256_add_pd, _mm256_fmadd_pd, _mm256_fmsub_pd, _mm256_loadu_pd, _mm256_mul_pd,
|
||||
_mm256_permute2f128_pd, _mm256_set_m128d, _mm256_storeu_pd, _mm256_sub_pd, _mm256_unpackhi_pd, _mm256_unpacklo_pd,
|
||||
};
|
||||
|
||||
use crate::cpu_fft64_avx::reim::{as_arr, as_arr_mut};
|
||||
|
||||
#[cfg(any(target_arch = "x86_64", target_arch = "x86"))]
|
||||
#[target_feature(enable = "avx2,fma")]
|
||||
pub(crate) fn fft_avx2_fma(m: usize, omg: &[f64], data: &mut [f64]) {
|
||||
if m < 16 {
|
||||
use poulpy_hal::reference::fft64::reim::fft_ref;
|
||||
|
||||
fft_ref(m, omg, data);
|
||||
return;
|
||||
}
|
||||
|
||||
assert!(data.len() == 2 * m);
|
||||
let (re, im) = data.split_at_mut(m);
|
||||
|
||||
if m == 16 {
|
||||
fft16_avx2_fma(
|
||||
as_arr_mut::<16, f64>(re),
|
||||
as_arr_mut::<16, f64>(im),
|
||||
as_arr::<16, f64>(omg),
|
||||
)
|
||||
} else if m <= 2048 {
|
||||
fft_bfs_16_avx2_fma(m, re, im, omg, 0);
|
||||
} else {
|
||||
fft_rec_16_avx2_fma(m, re, im, omg, 0);
|
||||
}
|
||||
}
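Dispatch summary and a hedged usage sketch (mirroring the test at the end of this file, not a separate API): sizes below 16 fall back to the scalar reference, sizes up to 2048 run a single breadth-first pass, and larger sizes recurse on halves before reaching the breadth-first kernel; runtime feature detection is the caller's responsibility.
// if std::is_x86_feature_detected!("avx2") && std::is_x86_feature_detected!("fma") {
//     unsafe { fft_avx2_fma(table.m(), table.omg(), &mut data) };
// } else {
//     poulpy_hal::reference::fft64::reim::fft_ref(table.m(), table.omg(), &mut data);
// }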
|
||||
|
||||
unsafe extern "sysv64" {
|
||||
unsafe fn fft16_avx2_fma_asm(re: *mut f64, im: *mut f64, omg: *const f64);
|
||||
}
|
||||
|
||||
#[cfg(any(target_arch = "x86_64", target_arch = "x86"))]
|
||||
#[target_feature(enable = "avx2,fma")]
|
||||
fn fft16_avx2_fma(re: &mut [f64; 16], im: &mut [f64; 16], omg: &[f64; 16]) {
|
||||
unsafe {
|
||||
fft16_avx2_fma_asm(re.as_mut_ptr(), im.as_mut_ptr(), omg.as_ptr());
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(any(target_arch = "x86_64", target_arch = "x86"))]
|
||||
#[target_feature(enable = "avx2,fma")]
|
||||
fn fft_rec_16_avx2_fma(m: usize, re: &mut [f64], im: &mut [f64], omg: &[f64], mut pos: usize) -> usize {
|
||||
if m <= 2048 {
|
||||
return fft_bfs_16_avx2_fma(m, re, im, omg, pos);
|
||||
};
|
||||
|
||||
let h: usize = m >> 1;
|
||||
twiddle_fft_avx2_fma(h, re, im, *as_arr::<2, f64>(&omg[pos..]));
|
||||
pos += 2;
|
||||
pos = fft_rec_16_avx2_fma(h, re, im, omg, pos);
|
||||
pos = fft_rec_16_avx2_fma(h, &mut re[h..], &mut im[h..], omg, pos);
|
||||
pos
|
||||
}
|
||||
|
||||
#[cfg(any(target_arch = "x86_64", target_arch = "x86"))]
|
||||
#[target_feature(enable = "avx2,fma")]
|
||||
fn fft_bfs_16_avx2_fma(m: usize, re: &mut [f64], im: &mut [f64], omg: &[f64], mut pos: usize) -> usize {
|
||||
let log_m: usize = (usize::BITS - (m - 1).leading_zeros()) as usize;
|
||||
let mut mm: usize = m;
|
||||
|
||||
if !log_m.is_multiple_of(2) {
|
||||
let h: usize = mm >> 1;
|
||||
twiddle_fft_avx2_fma(h, re, im, *as_arr::<2, f64>(&omg[pos..]));
|
||||
pos += 2;
|
||||
mm = h
|
||||
}
|
||||
|
||||
while mm > 16 {
|
||||
let h: usize = mm >> 2;
|
||||
for off in (0..m).step_by(mm) {
|
||||
bitwiddle_fft_avx2_fma(
|
||||
h,
|
||||
&mut re[off..],
|
||||
&mut im[off..],
|
||||
as_arr::<4, f64>(&omg[pos..]),
|
||||
);
|
||||
|
||||
pos += 4;
|
||||
}
|
||||
mm = h
|
||||
}
|
||||
|
||||
for off in (0..m).step_by(16) {
|
||||
fft16_avx2_fma(
|
||||
as_arr_mut::<16, f64>(&mut re[off..]),
|
||||
as_arr_mut::<16, f64>(&mut im[off..]),
|
||||
as_arr::<16, f64>(&omg[pos..]),
|
||||
);
|
||||
|
||||
pos += 16;
|
||||
}
|
||||
|
||||
pos
|
||||
}
|
||||
|
||||
#[cfg(any(target_arch = "x86_64", target_arch = "x86"))]
|
||||
#[target_feature(enable = "avx2,fma")]
|
||||
fn twiddle_fft_avx2_fma(h: usize, re: &mut [f64], im: &mut [f64], omg: [f64; 2]) {
|
||||
unsafe {
|
||||
let omx: __m128d = _mm_load_pd(omg.as_ptr());
|
||||
let omra: __m256d = _mm256_set_m128d(omx, omx);
|
||||
let omi: __m256d = _mm256_unpackhi_pd(omra, omra);
|
||||
let omr: __m256d = _mm256_unpacklo_pd(omra, omra);
|
||||
let mut r0: *mut f64 = re.as_mut_ptr();
|
||||
let mut r1: *mut f64 = re.as_mut_ptr().add(h);
|
||||
let mut i0: *mut f64 = im.as_mut_ptr();
|
||||
let mut i1: *mut f64 = im.as_mut_ptr().add(h);
|
||||
|
||||
for _ in (0..h).step_by(4) {
|
||||
let mut ur0: __m256d = _mm256_loadu_pd(r0);
|
||||
let mut ur1: __m256d = _mm256_loadu_pd(r1);
|
||||
let mut ui0: __m256d = _mm256_loadu_pd(i0);
|
||||
let mut ui1: __m256d = _mm256_loadu_pd(i1);
|
||||
let mut tra: __m256d = _mm256_mul_pd(omi, ui1);
|
||||
let mut tia: __m256d = _mm256_mul_pd(omi, ur1);
|
||||
|
||||
tra = _mm256_fmsub_pd(omr, ur1, tra);
|
||||
tia = _mm256_fmadd_pd(omr, ui1, tia);
|
||||
ur1 = _mm256_sub_pd(ur0, tra);
|
||||
ui1 = _mm256_sub_pd(ui0, tia);
|
||||
ur0 = _mm256_add_pd(ur0, tra);
|
||||
ui0 = _mm256_add_pd(ui0, tia);
|
||||
|
||||
_mm256_storeu_pd(r0, ur0);
|
||||
_mm256_storeu_pd(r1, ur1);
|
||||
_mm256_storeu_pd(i0, ui0);
|
||||
_mm256_storeu_pd(i1, ui1);
|
||||
|
||||
r0 = r0.add(4);
|
||||
r1 = r1.add(4);
|
||||
i0 = i0.add(4);
|
||||
i1 = i1.add(4);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(any(target_arch = "x86_64", target_arch = "x86"))]
|
||||
#[target_feature(enable = "avx2,fma")]
|
||||
fn bitwiddle_fft_avx2_fma(h: usize, re: &mut [f64], im: &mut [f64], omg: &[f64; 4]) {
|
||||
unsafe {
|
||||
let mut r0: *mut f64 = re.as_mut_ptr();
|
||||
let mut r1: *mut f64 = re.as_mut_ptr().add(h);
|
||||
let mut r2: *mut f64 = re.as_mut_ptr().add(2 * h);
|
||||
let mut r3: *mut f64 = re.as_mut_ptr().add(3 * h);
|
||||
let mut i0: *mut f64 = im.as_mut_ptr();
|
||||
let mut i1: *mut f64 = im.as_mut_ptr().add(h);
|
||||
let mut i2: *mut f64 = im.as_mut_ptr().add(2 * h);
|
||||
let mut i3: *mut f64 = im.as_mut_ptr().add(3 * h);
|
||||
let om0: __m256d = _mm256_loadu_pd(omg.as_ptr());
|
||||
let omb: __m256d = _mm256_permute2f128_pd(om0, om0, 0x11);
|
||||
let oma: __m256d = _mm256_permute2f128_pd(om0, om0, 0x00);
|
||||
let omai: __m256d = _mm256_unpackhi_pd(oma, oma);
|
||||
let omar: __m256d = _mm256_unpacklo_pd(oma, oma);
|
||||
let ombi: __m256d = _mm256_unpackhi_pd(omb, omb);
|
||||
let ombr: __m256d = _mm256_unpacklo_pd(omb, omb);
|
||||
for _ in (0..h).step_by(4) {
|
||||
let mut ur0: __m256d = _mm256_loadu_pd(r0);
|
||||
let mut ur1: __m256d = _mm256_loadu_pd(r1);
|
||||
let mut ur2: __m256d = _mm256_loadu_pd(r2);
|
||||
let mut ur3: __m256d = _mm256_loadu_pd(r3);
|
||||
let mut ui0: __m256d = _mm256_loadu_pd(i0);
|
||||
let mut ui1: __m256d = _mm256_loadu_pd(i1);
|
||||
let mut ui2: __m256d = _mm256_loadu_pd(i2);
|
||||
let mut ui3: __m256d = _mm256_loadu_pd(i3);
|
||||
|
||||
let mut tra: __m256d = _mm256_mul_pd(omai, ui2);
|
||||
let mut trb: __m256d = _mm256_mul_pd(omai, ui3);
|
||||
let mut tia: __m256d = _mm256_mul_pd(omai, ur2);
|
||||
let mut tib: __m256d = _mm256_mul_pd(omai, ur3);
|
||||
tra = _mm256_fmsub_pd(omar, ur2, tra);
|
||||
trb = _mm256_fmsub_pd(omar, ur3, trb);
|
||||
tia = _mm256_fmadd_pd(omar, ui2, tia);
|
||||
tib = _mm256_fmadd_pd(omar, ui3, tib);
|
||||
ur2 = _mm256_sub_pd(ur0, tra);
|
||||
ur3 = _mm256_sub_pd(ur1, trb);
|
||||
ui2 = _mm256_sub_pd(ui0, tia);
|
||||
ui3 = _mm256_sub_pd(ui1, tib);
|
||||
ur0 = _mm256_add_pd(ur0, tra);
|
||||
ur1 = _mm256_add_pd(ur1, trb);
|
||||
ui0 = _mm256_add_pd(ui0, tia);
|
||||
ui1 = _mm256_add_pd(ui1, tib);
|
||||
|
||||
tra = _mm256_mul_pd(ombi, ui1);
|
||||
trb = _mm256_mul_pd(ombr, ui3);
|
||||
tia = _mm256_mul_pd(ombi, ur1);
|
||||
tib = _mm256_mul_pd(ombr, ur3);
|
||||
tra = _mm256_fmsub_pd(ombr, ur1, tra);
|
||||
trb = _mm256_fmadd_pd(ombi, ur3, trb);
|
||||
tia = _mm256_fmadd_pd(ombr, ui1, tia);
|
||||
tib = _mm256_fmsub_pd(ombi, ui3, tib);
|
||||
ur1 = _mm256_sub_pd(ur0, tra);
|
||||
ur3 = _mm256_add_pd(ur2, trb);
|
||||
ui1 = _mm256_sub_pd(ui0, tia);
|
||||
ui3 = _mm256_add_pd(ui2, tib);
|
||||
ur0 = _mm256_add_pd(ur0, tra);
|
||||
ur2 = _mm256_sub_pd(ur2, trb);
|
||||
ui0 = _mm256_add_pd(ui0, tia);
|
||||
ui2 = _mm256_sub_pd(ui2, tib);
|
||||
|
||||
_mm256_storeu_pd(r0, ur0);
|
||||
_mm256_storeu_pd(r1, ur1);
|
||||
_mm256_storeu_pd(r2, ur2);
|
||||
_mm256_storeu_pd(r3, ur3);
|
||||
_mm256_storeu_pd(i0, ui0);
|
||||
_mm256_storeu_pd(i1, ui1);
|
||||
_mm256_storeu_pd(i2, ui2);
|
||||
_mm256_storeu_pd(i3, ui3);
|
||||
|
||||
r0 = r0.add(4);
|
||||
r1 = r1.add(4);
|
||||
r2 = r2.add(4);
|
||||
r3 = r3.add(4);
|
||||
i0 = i0.add(4);
|
||||
i1 = i1.add(4);
|
||||
i2 = i2.add(4);
|
||||
i3 = i3.add(4);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_fft_avx2_fma() {
|
||||
use super::*;
|
||||
|
||||
#[cfg(any(target_arch = "x86_64", target_arch = "x86"))]
|
||||
#[target_feature(enable = "avx2,fma")]
|
||||
fn internal(log_m: usize) {
|
||||
use poulpy_hal::reference::fft64::reim::ReimFFTRef;
|
||||
|
||||
let m = 1 << log_m;
|
||||
|
||||
let table: ReimFFTTable<f64> = ReimFFTTable::<f64>::new(m);
|
||||
|
||||
let mut values_0: Vec<f64> = vec![0f64; m << 1];
|
||||
let scale: f64 = 1.0f64 / m as f64;
|
||||
values_0
|
||||
.iter_mut()
|
||||
.enumerate()
|
||||
.for_each(|(i, x)| *x = (i + 1) as f64 * scale);
|
||||
|
||||
let mut values_1: Vec<f64> = vec![0f64; m << 1];
|
||||
values_1
|
||||
.iter_mut()
|
||||
.zip(values_0.iter())
|
||||
.for_each(|(y, x)| *y = *x);
|
||||
|
||||
ReimFFTAvx::reim_dft_execute(&table, &mut values_0);
|
||||
ReimFFTRef::reim_dft_execute(&table, &mut values_1);
|
||||
|
||||
let max_diff: f64 = 1.0 / ((1u64 << (53 - log_m - 1)) as f64);
|
||||
|
||||
for i in 0..m * 2 {
|
||||
let diff: f64 = (values_0[i] - values_1[i]).abs();
|
||||
assert!(
|
||||
diff <= max_diff,
|
||||
"{} -> {}-{} = {}",
|
||||
i,
|
||||
values_0[i],
|
||||
values_1[i],
|
||||
diff
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
if std::is_x86_feature_detected!("avx2") {
|
||||
for log_m in 0..16 {
|
||||
unsafe { internal(log_m) }
|
||||
}
|
||||
} else {
|
||||
eprintln!("skipping: CPU lacks avx2");
|
||||
}
|
||||
}
|
||||
poulpy-backend/src/cpu_fft64_avx/reim/fft_vec_avx2_fma.rs (new file, 350 lines)
@@ -0,0 +1,350 @@
|
||||
/// # Safety
|
||||
/// Caller must ensure the CPU supports AVX2 (e.g., via `is_x86_feature_detected!("avx2")`);
|
||||
#[cfg(any(target_arch = "x86_64", target_arch = "x86"))]
|
||||
#[target_feature(enable = "avx2,fma")]
|
||||
pub fn reim_add_avx2_fma(res: &mut [f64], a: &[f64], b: &[f64]) {
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
assert_eq!(a.len(), res.len());
|
||||
assert_eq!(b.len(), res.len());
|
||||
}
|
||||
|
||||
use std::arch::x86_64::{__m256d, _mm256_add_pd, _mm256_loadu_pd, _mm256_storeu_pd};
|
||||
|
||||
let span: usize = res.len() >> 2;
|
||||
|
||||
unsafe {
|
||||
let mut rr: *mut f64 = res.as_mut_ptr();
|
||||
let mut aa: *const f64 = a.as_ptr();
|
||||
let mut bb: *const f64 = b.as_ptr();
|
||||
|
||||
for _ in 0..span {
|
||||
let a_256: __m256d = _mm256_loadu_pd(aa);
|
||||
let b_256: __m256d = _mm256_loadu_pd(bb);
|
||||
_mm256_storeu_pd(rr, _mm256_add_pd(a_256, b_256));
|
||||
rr = rr.add(4);
|
||||
aa = aa.add(4);
|
||||
bb = bb.add(4);
|
||||
}
|
||||
}
|
||||
}
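Note that, unlike the conversion kernels, these element-wise helpers have no scalar tail: `span = res.len() >> 2` silently drops any remainder, so they appear to assume lengths that are multiples of 4 (true for the power-of-two sizes this backend targets). A guard one could add under that assumption:
// Hypothetical debug guard, if the multiple-of-4 assumption should be made explicit:
#[cfg(debug_assertions)]
{
    assert!(res.len().is_multiple_of(4));
}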
|
||||
|
||||
/// # Safety
|
||||
/// Caller must ensure the CPU supports AVX2 (e.g., via `is_x86_feature_detected!("avx2")`);
|
||||
#[cfg(any(target_arch = "x86_64", target_arch = "x86"))]
|
||||
#[target_feature(enable = "avx2,fma")]
|
||||
pub fn reim_add_inplace_avx2_fma(res: &mut [f64], a: &[f64]) {
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
assert_eq!(a.len(), res.len());
|
||||
}
|
||||
|
||||
use std::arch::x86_64::{__m256d, _mm256_add_pd, _mm256_loadu_pd, _mm256_storeu_pd};
|
||||
|
||||
let span: usize = res.len() >> 2;
|
||||
|
||||
unsafe {
|
||||
let mut rr: *mut f64 = res.as_mut_ptr();
|
||||
let mut aa: *const f64 = a.as_ptr();
|
||||
|
||||
for _ in 0..span {
|
||||
let a_256: __m256d = _mm256_loadu_pd(aa);
|
||||
let r_256: __m256d = _mm256_loadu_pd(rr);
|
||||
_mm256_storeu_pd(rr, _mm256_add_pd(r_256, a_256));
|
||||
rr = rr.add(4);
|
||||
aa = aa.add(4);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// # Safety
|
||||
/// Caller must ensure the CPU supports AVX2 (e.g., via `is_x86_feature_detected!("avx2")`);
|
||||
#[cfg(any(target_arch = "x86_64", target_arch = "x86"))]
|
||||
#[target_feature(enable = "avx2,fma")]
|
||||
pub fn reim_sub_avx2_fma(res: &mut [f64], a: &[f64], b: &[f64]) {
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
assert_eq!(a.len(), res.len());
|
||||
assert_eq!(b.len(), res.len());
|
||||
}
|
||||
|
||||
use std::arch::x86_64::{__m256d, _mm256_loadu_pd, _mm256_storeu_pd, _mm256_sub_pd};
|
||||
|
||||
let span: usize = res.len() >> 2;
|
||||
|
||||
unsafe {
|
||||
let mut rr: *mut f64 = res.as_mut_ptr();
|
||||
let mut aa: *const f64 = a.as_ptr();
|
||||
let mut bb: *const f64 = b.as_ptr();
|
||||
|
||||
for _ in 0..span {
|
||||
let a_256: __m256d = _mm256_loadu_pd(aa);
|
||||
let b_256: __m256d = _mm256_loadu_pd(bb);
|
||||
_mm256_storeu_pd(rr, _mm256_sub_pd(a_256, b_256));
|
||||
rr = rr.add(4);
|
||||
aa = aa.add(4);
|
||||
bb = bb.add(4);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// # Safety
|
||||
/// Caller must ensure the CPU supports AVX2 (e.g., via `is_x86_feature_detected!("avx2")`);
|
||||
#[cfg(any(target_arch = "x86_64", target_arch = "x86"))]
|
||||
#[target_feature(enable = "avx2,fma")]
|
||||
pub fn reim_sub_ab_inplace_avx2_fma(res: &mut [f64], a: &[f64]) {
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
assert_eq!(a.len(), res.len());
|
||||
}
|
||||
|
||||
use std::arch::x86_64::{__m256d, _mm256_loadu_pd, _mm256_storeu_pd, _mm256_sub_pd};
|
||||
|
||||
let span: usize = res.len() >> 2;
|
||||
|
||||
unsafe {
|
||||
let mut rr: *mut f64 = res.as_mut_ptr();
|
||||
let mut aa: *const f64 = a.as_ptr();
|
||||
|
||||
for _ in 0..span {
|
||||
let a_256: __m256d = _mm256_loadu_pd(aa);
|
||||
let r_256: __m256d = _mm256_loadu_pd(rr);
|
||||
_mm256_storeu_pd(rr, _mm256_sub_pd(r_256, a_256));
|
||||
rr = rr.add(4);
|
||||
aa = aa.add(4);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// # Safety
|
||||
/// Caller must ensure the CPU supports AVX2 (e.g., via `is_x86_feature_detected!("avx2")`);
|
||||
#[cfg(any(target_arch = "x86_64", target_arch = "x86"))]
|
||||
#[target_feature(enable = "avx2,fma")]
|
||||
pub fn reim_sub_ba_inplace_avx2_fma(res: &mut [f64], a: &[f64]) {
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
assert_eq!(a.len(), res.len());
|
||||
}
|
||||
|
||||
use std::arch::x86_64::{__m256d, _mm256_loadu_pd, _mm256_storeu_pd, _mm256_sub_pd};
|
||||
|
||||
let span: usize = res.len() >> 2;
|
||||
|
||||
unsafe {
|
||||
let mut rr: *mut f64 = res.as_mut_ptr();
|
||||
let mut aa: *const f64 = a.as_ptr();
|
||||
|
||||
for _ in 0..span {
|
||||
let a_256: __m256d = _mm256_loadu_pd(aa);
|
||||
let r_256: __m256d = _mm256_loadu_pd(rr);
|
||||
_mm256_storeu_pd(rr, _mm256_sub_pd(a_256, r_256));
|
||||
rr = rr.add(4);
|
||||
aa = aa.add(4);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// # Safety
|
||||
/// Caller must ensure the CPU supports AVX2 (e.g., via `is_x86_feature_detected!("avx2")`);
|
||||
#[cfg(any(target_arch = "x86_64", target_arch = "x86"))]
|
||||
#[target_feature(enable = "avx2,fma")]
|
||||
pub fn reim_negate_avx2_fma(res: &mut [f64], a: &[f64]) {
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
assert_eq!(a.len(), res.len());
|
||||
}
|
||||
|
||||
use std::arch::x86_64::{__m256d, _mm256_loadu_pd, _mm256_storeu_pd, _mm256_xor_pd};
|
||||
|
||||
let span: usize = res.len() >> 2;
|
||||
|
||||
unsafe {
|
||||
use std::arch::x86_64::_mm256_set1_pd;
|
||||
|
||||
let mut rr: *mut f64 = res.as_mut_ptr();
|
||||
let mut aa: *const f64 = a.as_ptr();
|
||||
|
||||
let neg0: __m256d = _mm256_set1_pd(-0.0);
|
||||
|
||||
for _ in 0..span {
|
||||
let a_256: __m256d = _mm256_loadu_pd(aa);
|
||||
_mm256_storeu_pd(rr, _mm256_xor_pd(a_256, neg0));
|
||||
rr = rr.add(4);
|
||||
aa = aa.add(4);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// # Safety
|
||||
/// Caller must ensure the CPU supports AVX2 (e.g., via `is_x86_feature_detected!("avx2")`);
|
||||
#[cfg(any(target_arch = "x86_64", target_arch = "x86"))]
|
||||
#[target_feature(enable = "avx2,fma")]
|
||||
pub fn reim_negate_inplace_avx2_fma(res: &mut [f64]) {
|
||||
use std::arch::x86_64::{__m256d, _mm256_loadu_pd, _mm256_storeu_pd, _mm256_xor_pd};
|
||||
|
||||
let span: usize = res.len() >> 2;
|
||||
|
||||
unsafe {
|
||||
use std::arch::x86_64::_mm256_set1_pd;
|
||||
|
||||
let mut rr: *mut f64 = res.as_mut_ptr();
|
||||
let neg0: __m256d = _mm256_set1_pd(-0.0);
|
||||
|
||||
for _ in 0..span {
|
||||
let r_256: __m256d = _mm256_loadu_pd(rr);
|
||||
_mm256_storeu_pd(rr, _mm256_xor_pd(r_256, neg0));
|
||||
rr = rr.add(4);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// # Safety
|
||||
/// Caller must ensure the CPU supports AVX2 and FMA (e.g., via `is_x86_feature_detected!("avx2")` and `is_x86_feature_detected!("fma")`).
|
||||
#[cfg(any(target_arch = "x86_64", target_arch = "x86"))]
|
||||
#[target_feature(enable = "avx2,fma")]
|
||||
pub fn reim_addmul_avx2_fma(res: &mut [f64], a: &[f64], b: &[f64]) {
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
assert_eq!(a.len(), res.len());
|
||||
assert_eq!(b.len(), res.len());
|
||||
}
|
||||
|
||||
let m: usize = res.len() >> 1;
|
||||
|
||||
let (rr, ri) = res.split_at_mut(m);
|
||||
let (ar, ai) = a.split_at(m);
|
||||
let (br, bi) = b.split_at(m);
|
||||
|
||||
unsafe {
|
||||
let mut rr_ptr: *mut f64 = rr.as_mut_ptr();
|
||||
let mut ri_ptr: *mut f64 = ri.as_mut_ptr();
|
||||
let mut ar_ptr: *const f64 = ar.as_ptr();
|
||||
let mut ai_ptr: *const f64 = ai.as_ptr();
|
||||
let mut br_ptr: *const f64 = br.as_ptr();
|
||||
let mut bi_ptr: *const f64 = bi.as_ptr();
|
||||
|
||||
use std::arch::x86_64::{__m256d, _mm256_fmadd_pd, _mm256_fmsub_pd, _mm256_loadu_pd, _mm256_storeu_pd};
|
||||
|
||||
for _ in 0..(m >> 2) {
|
||||
let mut rr: __m256d = _mm256_loadu_pd(rr_ptr);
|
||||
let mut ri: __m256d = _mm256_loadu_pd(ri_ptr);
|
||||
let ar: __m256d = _mm256_loadu_pd(ar_ptr);
|
||||
let ai: __m256d = _mm256_loadu_pd(ai_ptr);
|
||||
let br: __m256d = _mm256_loadu_pd(br_ptr);
|
||||
let bi: __m256d = _mm256_loadu_pd(bi_ptr);
|
||||
|
||||
rr = _mm256_fmsub_pd(ai, bi, rr);
|
||||
rr = _mm256_fmsub_pd(ar, br, rr);
|
||||
ri = _mm256_fmadd_pd(ar, bi, ri);
|
||||
ri = _mm256_fmadd_pd(ai, br, ri);
|
||||
|
||||
_mm256_storeu_pd(rr_ptr, rr);
|
||||
_mm256_storeu_pd(ri_ptr, ri);
|
||||
|
||||
rr_ptr = rr_ptr.add(4);
|
||||
ri_ptr = ri_ptr.add(4);
|
||||
ar_ptr = ar_ptr.add(4);
|
||||
ai_ptr = ai_ptr.add(4);
|
||||
br_ptr = br_ptr.add(4);
|
||||
bi_ptr = bi_ptr.add(4);
|
||||
}
|
||||
}
|
||||
}
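Layout note and scalar model (derived from the `split_at` calls above): the first half of each slice holds real parts, the second half imaginary parts, and the kernel accumulates res += a * b over these split complex values.
// Plain-Rust equivalent of reim_addmul over split re/im halves:
fn reim_addmul_scalar(res: &mut [f64], a: &[f64], b: &[f64]) {
    let m = res.len() / 2;
    let (rr, ri) = res.split_at_mut(m);
    let (ar, ai) = a.split_at(m);
    let (br, bi) = b.split_at(m);
    for j in 0..m {
        rr[j] += ar[j] * br[j] - ai[j] * bi[j]; // real part of a[j] * b[j]
        ri[j] += ar[j] * bi[j] + ai[j] * br[j]; // imaginary part of a[j] * b[j]
    }
}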
|
||||
|
||||
/// # Safety
|
||||
/// Caller must ensure the CPU supports AVX2 and FMA (e.g., via `is_x86_feature_detected!("avx2")` and `is_x86_feature_detected!("fma")`).
|
||||
#[cfg(any(target_arch = "x86_64", target_arch = "x86"))]
|
||||
#[target_feature(enable = "avx2,fma")]
|
||||
pub fn reim_mul_avx2_fma(res: &mut [f64], a: &[f64], b: &[f64]) {
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
assert_eq!(a.len(), res.len());
|
||||
assert_eq!(b.len(), res.len());
|
||||
}
|
||||
|
||||
let m: usize = res.len() >> 1;
|
||||
|
||||
let (rr, ri) = res.split_at_mut(m);
|
||||
let (ar, ai) = a.split_at(m);
|
||||
let (br, bi) = b.split_at(m);
|
||||
|
||||
unsafe {
|
||||
let mut rr_ptr: *mut f64 = rr.as_mut_ptr();
|
||||
let mut ri_ptr: *mut f64 = ri.as_mut_ptr();
|
||||
let mut ar_ptr: *const f64 = ar.as_ptr();
|
||||
let mut ai_ptr: *const f64 = ai.as_ptr();
|
||||
let mut br_ptr: *const f64 = br.as_ptr();
|
||||
let mut bi_ptr: *const f64 = bi.as_ptr();
|
||||
|
||||
use std::arch::x86_64::{__m256d, _mm256_fmadd_pd, _mm256_fmsub_pd, _mm256_loadu_pd, _mm256_mul_pd, _mm256_storeu_pd};
|
||||
|
||||
for _ in 0..(m >> 2) {
|
||||
let ar: __m256d = _mm256_loadu_pd(ar_ptr);
|
||||
let ai: __m256d = _mm256_loadu_pd(ai_ptr);
|
||||
let br: __m256d = _mm256_loadu_pd(br_ptr);
|
||||
let bi: __m256d = _mm256_loadu_pd(bi_ptr);
|
||||
|
||||
let t1: __m256d = _mm256_mul_pd(ai, bi);
|
||||
let t2: __m256d = _mm256_mul_pd(ar, bi);
|
||||
|
||||
let rr: __m256d = _mm256_fmsub_pd(ar, br, t1);
|
||||
let ri: __m256d = _mm256_fmadd_pd(ai, br, t2);
|
||||
|
||||
_mm256_storeu_pd(rr_ptr, rr);
|
||||
_mm256_storeu_pd(ri_ptr, ri);
|
||||
|
||||
rr_ptr = rr_ptr.add(4);
|
||||
ri_ptr = ri_ptr.add(4);
|
||||
ar_ptr = ar_ptr.add(4);
|
||||
ai_ptr = ai_ptr.add(4);
|
||||
br_ptr = br_ptr.add(4);
|
||||
bi_ptr = bi_ptr.add(4);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// # Safety
|
||||
/// Caller must ensure the CPU supports AVX2 and FMA (e.g., via `is_x86_feature_detected!("avx2")` and `is_x86_feature_detected!("fma")`).
|
||||
#[cfg(any(target_arch = "x86_64", target_arch = "x86"))]
|
||||
#[target_feature(enable = "avx2,fma")]
|
||||
pub fn reim_mul_inplace_avx2_fma(res: &mut [f64], a: &[f64]) {
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
assert_eq!(a.len(), res.len());
|
||||
}
|
||||
|
||||
let m: usize = res.len() >> 1;
|
||||
|
||||
let (rr, ri) = res.split_at_mut(m);
|
||||
let (ar, ai) = a.split_at(m);
|
||||
|
||||
unsafe {
|
||||
let mut rr_ptr: *mut f64 = rr.as_mut_ptr();
|
||||
let mut ri_ptr: *mut f64 = ri.as_mut_ptr();
|
||||
let mut ar_ptr: *const f64 = ar.as_ptr();
|
||||
let mut ai_ptr: *const f64 = ai.as_ptr();
|
||||
|
||||
use std::arch::x86_64::{__m256d, _mm256_fmadd_pd, _mm256_fmsub_pd, _mm256_loadu_pd, _mm256_mul_pd, _mm256_storeu_pd};
|
||||
|
||||
for _ in 0..(m >> 2) {
|
||||
let ar: __m256d = _mm256_loadu_pd(ar_ptr);
|
||||
let ai: __m256d = _mm256_loadu_pd(ai_ptr);
|
||||
let br: __m256d = _mm256_loadu_pd(rr_ptr);
|
||||
let bi: __m256d = _mm256_loadu_pd(ri_ptr);
|
||||
|
||||
let t1: __m256d = _mm256_mul_pd(ai, bi);
|
||||
let t2: __m256d = _mm256_mul_pd(ar, bi);
|
||||
|
||||
let rr = _mm256_fmsub_pd(ar, br, t1);
|
||||
let ri = _mm256_fmadd_pd(ai, br, t2);
|
||||
|
||||
_mm256_storeu_pd(rr_ptr, rr);
|
||||
_mm256_storeu_pd(ri_ptr, ri);
|
||||
|
||||
rr_ptr = rr_ptr.add(4);
|
||||
ri_ptr = ri_ptr.add(4);
|
||||
ar_ptr = ar_ptr.add(4);
|
||||
ai_ptr = ai_ptr.add(4);
|
||||
}
|
||||
}
|
||||
}
|
||||
poulpy-backend/src/cpu_fft64_avx/reim/ifft16_avx2_fma.s (new file, 181 lines)
@@ -0,0 +1,181 @@
|
||||
# ----------------------------------------------------------------------
|
||||
# This kernel is a direct port of the IFFT16 routine from spqlios-arithmetic
|
||||
# (https://github.com/tfhe/spqlios-arithmetic)
|
||||
# ----------------------------------------------------------------------
|
||||
#
|
||||
|
||||
.text
|
||||
.globl ifft16_avx2_fma_asm
|
||||
.hidden ifft16_avx2_fma_asm
|
||||
.p2align 4, 0x90
|
||||
.type ifft16_avx2_fma_asm,@function
|
||||
ifft16_avx2_fma_asm:
|
||||
.att_syntax prefix
|
||||
|
||||
vmovupd (%rdi),%ymm0 # ra0
|
||||
vmovupd 0x20(%rdi),%ymm1 # ra4
|
||||
vmovupd 0x40(%rdi),%ymm2 # ra8
|
||||
vmovupd 0x60(%rdi),%ymm3 # ra12
|
||||
vmovupd (%rsi),%ymm4 # ia0
|
||||
vmovupd 0x20(%rsi),%ymm5 # ia4
|
||||
vmovupd 0x40(%rsi),%ymm6 # ia8
|
||||
vmovupd 0x60(%rsi),%ymm7 # ia12
|
||||
|
||||
1:
|
||||
vmovupd 0x00(%rdx),%ymm12
|
||||
vmovupd 0x20(%rdx),%ymm13
|
||||
|
||||
vperm2f128 $0x31,%ymm2,%ymm0,%ymm8 # ymm8 contains re to mul (tw)
|
||||
vperm2f128 $0x31,%ymm3,%ymm1,%ymm9 # ymm9 contains re to mul (itw)
|
||||
vperm2f128 $0x31,%ymm6,%ymm4,%ymm10 # ymm10 contains im to mul (tw)
|
||||
vperm2f128 $0x31,%ymm7,%ymm5,%ymm11 # ymm11 contains im to mul (itw)
|
||||
vperm2f128 $0x20,%ymm2,%ymm0,%ymm0 # ymm0 contains re to add (tw)
|
||||
vperm2f128 $0x20,%ymm3,%ymm1,%ymm1 # ymm1 contains re to add (itw)
|
||||
vperm2f128 $0x20,%ymm6,%ymm4,%ymm2 # ymm2 contains im to add (tw)
|
||||
vperm2f128 $0x20,%ymm7,%ymm5,%ymm3 # ymm3 contains im to add (itw)
|
||||
|
||||
vunpckhpd %ymm1,%ymm0,%ymm4 # (0,1) -> (0,4)
|
||||
vunpckhpd %ymm3,%ymm2,%ymm6 # (2,3) -> (2,6)
|
||||
vunpckhpd %ymm9,%ymm8,%ymm5 # (8,9) -> (1,5)
|
||||
vunpckhpd %ymm11,%ymm10,%ymm7 # (10,11) -> (3,7)
|
||||
vunpcklpd %ymm1,%ymm0,%ymm0
|
||||
vunpcklpd %ymm3,%ymm2,%ymm2
|
||||
vunpcklpd %ymm9,%ymm8,%ymm1
|
||||
vunpcklpd %ymm11,%ymm10,%ymm3
|
||||
|
||||
# invctwiddle Re:(ymm0,ymm4) and Im:(ymm2,ymm6) with omega=(ymm12,ymm13)
|
||||
# invcitwiddle Re:(ymm1,ymm5) and Im:(ymm3,ymm7) with omega=(ymm12,ymm13)
|
||||
vsubpd %ymm4,%ymm0,%ymm8 # retw
|
||||
vsubpd %ymm5,%ymm1,%ymm9 # reitw
|
||||
vsubpd %ymm6,%ymm2,%ymm10 # imtw
|
||||
vsubpd %ymm7,%ymm3,%ymm11 # imitw
|
||||
vaddpd %ymm4,%ymm0,%ymm0
|
||||
vaddpd %ymm5,%ymm1,%ymm1
|
||||
vaddpd %ymm6,%ymm2,%ymm2
|
||||
vaddpd %ymm7,%ymm3,%ymm3
|
||||
# multiply 8,9,10,11 by 12,13, result to: 4,5,6,7
|
||||
# twiddles use reom=ymm12, imom=ymm13
|
||||
# invtwiddles use reom=ymm13, imom=-ymm12
|
||||
vmulpd %ymm10,%ymm13,%ymm4 # imtw.omai (tw)
|
||||
vmulpd %ymm11,%ymm12,%ymm5 # imitw.omar (itw)
|
||||
vmulpd %ymm8,%ymm13,%ymm6 # retw.omai (tw)
|
||||
vmulpd %ymm9,%ymm12,%ymm7 # reitw.omar (itw)
|
||||
vfmsub231pd %ymm8,%ymm12,%ymm4 # rprod0 (tw)
|
||||
vfmadd231pd %ymm9,%ymm13,%ymm5 # rprod4 (itw)
|
||||
vfmadd231pd %ymm10,%ymm12,%ymm6 # iprod0 (tw)
|
||||
vfmsub231pd %ymm11,%ymm13,%ymm7 # iprod4 (itw)
|
||||
|
||||
vunpckhpd %ymm7,%ymm3,%ymm11 # (0,4) -> (0,1)
|
||||
vunpckhpd %ymm5,%ymm1,%ymm9 # (2,6) -> (2,3)
|
||||
vunpcklpd %ymm7,%ymm3,%ymm10
|
||||
vunpcklpd %ymm5,%ymm1,%ymm8
|
||||
vunpckhpd %ymm6,%ymm2,%ymm3 # (1,5) -> (8,9)
|
||||
vunpckhpd %ymm4,%ymm0,%ymm1 # (3,7) -> (10,11)
|
||||
vunpcklpd %ymm6,%ymm2,%ymm2
|
||||
vunpcklpd %ymm4,%ymm0,%ymm0
|
||||
|
||||
2:
|
||||
vmovupd 0x40(%rdx),%ymm12
|
||||
vshufpd $15, %ymm12, %ymm12, %ymm13 # ymm13: omaiii'i'
|
||||
vshufpd $0, %ymm12, %ymm12, %ymm12 # ymm12: omarrr'r'
|
||||
|
||||
# invctwiddle Re:(ymm0,ymm8) and Im:(ymm2,ymm10) with omega=(ymm12,ymm13)
|
||||
# invcitwiddle Re:(ymm1,ymm9) and Im:(ymm3,ymm11) with omega=(ymm12,ymm13)
|
||||
vsubpd %ymm8,%ymm0,%ymm4 # retw
|
||||
vsubpd %ymm9,%ymm1,%ymm5 # reitw
|
||||
vsubpd %ymm10,%ymm2,%ymm6 # imtw
|
||||
vsubpd %ymm11,%ymm3,%ymm7 # imitw
|
||||
vaddpd %ymm8,%ymm0,%ymm0
|
||||
vaddpd %ymm9,%ymm1,%ymm1
|
||||
vaddpd %ymm10,%ymm2,%ymm2
|
||||
vaddpd %ymm11,%ymm3,%ymm3
|
||||
# multiply 4,5,6,7 by 12,13, result to 8,9,10,11
|
||||
# twiddles use reom=ymm12, imom=ymm13
|
||||
# invtwiddles use reom=ymm13, imom=-ymm12
|
||||
vmulpd %ymm6,%ymm13,%ymm8 # imtw.omai (tw)
|
||||
vmulpd %ymm7,%ymm12,%ymm9 # imitw.omar (itw)
|
||||
vmulpd %ymm4,%ymm13,%ymm10 # retw.omai (tw)
|
||||
vmulpd %ymm5,%ymm12,%ymm11 # reitw.omar (itw)
|
||||
vfmsub231pd %ymm4,%ymm12,%ymm8 # rprod0 (tw)
|
||||
vfmadd231pd %ymm5,%ymm13,%ymm9 # rprod4 (itw)
|
||||
vfmadd231pd %ymm6,%ymm12,%ymm10 # iprod0 (tw)
|
||||
vfmsub231pd %ymm7,%ymm13,%ymm11 # iprod4 (itw)
|
||||
|
||||
vperm2f128 $0x31,%ymm10,%ymm2,%ymm6
|
||||
vperm2f128 $0x31,%ymm11,%ymm3,%ymm7
|
||||
vperm2f128 $0x20,%ymm10,%ymm2,%ymm4
|
||||
vperm2f128 $0x20,%ymm11,%ymm3,%ymm5
|
||||
vperm2f128 $0x31,%ymm8,%ymm0,%ymm2
|
||||
vperm2f128 $0x31,%ymm9,%ymm1,%ymm3
|
||||
vperm2f128 $0x20,%ymm8,%ymm0,%ymm0
|
||||
vperm2f128 $0x20,%ymm9,%ymm1,%ymm1
|
||||
|
||||
3:
|
||||
vmovupd 0x60(%rdx),%xmm12
|
||||
vinsertf128 $1, %xmm12, %ymm12, %ymm12 # omriri
|
||||
vshufpd $15, %ymm12, %ymm12, %ymm13 # ymm13: omai
|
||||
vshufpd $0, %ymm12, %ymm12, %ymm12 # ymm12: omar
|
||||
|
||||
# invctwiddle Re:(ymm0,ymm1) and Im:(ymm4,ymm5) with omega=(ymm12,ymm13)
|
||||
# invcitwiddle Re:(ymm2,ymm3) and Im:(ymm6,ymm7) with omega=(ymm12,ymm13)
|
||||
vsubpd %ymm1,%ymm0,%ymm8 # retw
|
||||
vsubpd %ymm3,%ymm2,%ymm9 # reitw
|
||||
vsubpd %ymm5,%ymm4,%ymm10 # imtw
|
||||
vsubpd %ymm7,%ymm6,%ymm11 # imitw
|
||||
vaddpd %ymm1,%ymm0,%ymm0
|
||||
vaddpd %ymm3,%ymm2,%ymm2
|
||||
vaddpd %ymm5,%ymm4,%ymm4
|
||||
vaddpd %ymm7,%ymm6,%ymm6
|
||||
# multiply 8,9,10,11 by 12,13, result to 1,3,5,7
|
||||
# twiddles use reom=ymm12, imom=ymm13
|
||||
# invtwiddles use reom=ymm13, imom=-ymm12
|
||||
vmulpd %ymm10,%ymm13,%ymm1 # imtw.omai (tw)
|
||||
vmulpd %ymm11,%ymm12,%ymm3 # imitw.omar (itw)
|
||||
vmulpd %ymm8,%ymm13,%ymm5 # retw.omai (tw)
|
||||
vmulpd %ymm9,%ymm12,%ymm7 # reitw.omar (itw)
|
||||
vfmsub231pd %ymm8,%ymm12,%ymm1 # rprod0 (tw)
|
||||
vfmadd231pd %ymm9,%ymm13,%ymm3 # rprod4 (itw)
|
||||
vfmadd231pd %ymm10,%ymm12,%ymm5 # iprod0 (tw)
|
||||
vfmsub231pd %ymm11,%ymm13,%ymm7 # iprod4 (itw)
|
||||
|
||||
4:
|
||||
vmovupd 0x70(%rdx),%xmm12
|
||||
vinsertf128 $1, %xmm12, %ymm12, %ymm12 # omriri
|
||||
vshufpd $15, %ymm12, %ymm12, %ymm13 # ymm13: omai
|
||||
vshufpd $0, %ymm12, %ymm12, %ymm12 # ymm12: omar
|
||||
|
||||
# invctwiddle Re:(ymm0,ymm2) and Im:(ymm4,ymm6) with omega=(ymm12,ymm13)
|
||||
# invctwiddle Re:(ymm1,ymm3) and Im:(ymm5,ymm7) with omega=(ymm12,ymm13)
|
||||
vsubpd %ymm2,%ymm0,%ymm8 # retw1
|
||||
vsubpd %ymm3,%ymm1,%ymm9 # retw2
|
||||
vsubpd %ymm6,%ymm4,%ymm10 # imtw1
|
||||
vsubpd %ymm7,%ymm5,%ymm11 # imtw2
|
||||
vaddpd %ymm2,%ymm0,%ymm0
|
||||
vaddpd %ymm3,%ymm1,%ymm1
|
||||
vaddpd %ymm6,%ymm4,%ymm4
|
||||
vaddpd %ymm7,%ymm5,%ymm5
|
||||
# multiply 8,9,10,11 by 12,13, result to 2,3,6,7
|
||||
# twiddles use reom=ymm12, imom=ymm13
|
||||
vmulpd %ymm10,%ymm13,%ymm2 # imtw1.omai
|
||||
vmulpd %ymm11,%ymm13,%ymm3 # imtw2.omai
|
||||
vmulpd %ymm8,%ymm13,%ymm6 # retw1.omai
|
||||
vmulpd %ymm9,%ymm13,%ymm7 # retw2.omai
|
||||
vfmsub231pd %ymm8,%ymm12,%ymm2 # rprod0
|
||||
vfmsub231pd %ymm9,%ymm12,%ymm3 # rprod4
|
||||
vfmadd231pd %ymm10,%ymm12,%ymm6 # iprod0
|
||||
vfmadd231pd %ymm11,%ymm12,%ymm7 # iprod4
|
||||
|
||||
5:
|
||||
vmovupd %ymm0,(%rdi) # ra0
|
||||
vmovupd %ymm1,0x20(%rdi) # ra4
|
||||
vmovupd %ymm2,0x40(%rdi) # ra8
|
||||
vmovupd %ymm3,0x60(%rdi) # ra12
|
||||
vmovupd %ymm4,(%rsi) # ia0
|
||||
vmovupd %ymm5,0x20(%rsi) # ia4
|
||||
vmovupd %ymm6,0x40(%rsi) # ia8
|
||||
vmovupd %ymm7,0x60(%rsi) # ia12
|
||||
vzeroupper
|
||||
ret
|
||||
|
||||
.size ifft16_avx2_fma_asm, .-ifft16_avx2_fma_asm
|
||||
.section .note.GNU-stack,"",@progbits
|
||||
poulpy-backend/src/cpu_fft64_avx/reim/ifft_avx2_fma.rs (new file, 271 lines)
@@ -0,0 +1,271 @@
|
||||
use std::arch::x86_64::{
|
||||
__m128d, __m256d, _mm_load_pd, _mm256_add_pd, _mm256_fmadd_pd, _mm256_fmsub_pd, _mm256_loadu_pd, _mm256_mul_pd,
|
||||
_mm256_permute2f128_pd, _mm256_set_m128d, _mm256_storeu_pd, _mm256_sub_pd, _mm256_unpackhi_pd, _mm256_unpacklo_pd,
|
||||
};
|
||||
|
||||
use crate::cpu_fft64_avx::reim::{as_arr, as_arr_mut};
|
||||
|
||||
#[cfg(target_arch = "x86_64")]
|
||||
#[target_feature(enable = "avx2,fma")]
|
||||
pub(crate) fn ifft_avx2_fma(m: usize, omg: &[f64], data: &mut [f64]) {
|
||||
if m < 16 {
|
||||
use poulpy_hal::reference::fft64::reim::ifft_ref;
|
||||
ifft_ref(m, omg, data);
|
||||
return;
|
||||
}
|
||||
|
||||
assert!(data.len() == 2 * m);
|
||||
let (re, im) = data.split_at_mut(m);
|
||||
|
||||
if m == 16 {
|
||||
ifft16_avx2_fma(
|
||||
as_arr_mut::<16, f64>(re),
|
||||
as_arr_mut::<16, f64>(im),
|
||||
as_arr::<16, f64>(omg),
|
||||
)
|
||||
} else if m <= 2048 {
|
||||
ifft_bfs_16_avx2_fma(m, re, im, omg, 0);
|
||||
} else {
|
||||
ifft_rec_16_avx2_fma(m, re, im, omg, 0);
|
||||
}
|
||||
}
|
||||
|
||||
unsafe extern "sysv64" {
|
||||
unsafe fn ifft16_avx2_fma_asm(re: *mut f64, im: *mut f64, omg: *const f64);
|
||||
}
|
||||
|
||||
#[cfg(any(target_arch = "x86_64", target_arch = "x86"))]
|
||||
#[target_feature(enable = "avx2,fma")]
|
||||
fn ifft16_avx2_fma(re: &mut [f64; 16], im: &mut [f64; 16], omg: &[f64; 16]) {
|
||||
unsafe {
|
||||
ifft16_avx2_fma_asm(re.as_mut_ptr(), im.as_mut_ptr(), omg.as_ptr());
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(target_arch = "x86_64")]
|
||||
#[target_feature(enable = "avx2,fma")]
|
||||
fn ifft_rec_16_avx2_fma(m: usize, re: &mut [f64], im: &mut [f64], omg: &[f64], mut pos: usize) -> usize {
|
||||
if m <= 2048 {
|
||||
return ifft_bfs_16_avx2_fma(m, re, im, omg, pos);
|
||||
};
|
||||
let h: usize = m >> 1;
|
||||
pos = ifft_rec_16_avx2_fma(h, re, im, omg, pos);
|
||||
pos = ifft_rec_16_avx2_fma(h, &mut re[h..], &mut im[h..], omg, pos);
|
||||
inv_twiddle_ifft_avx2_fma(h, re, im, *as_arr::<2, f64>(&omg[pos..]));
|
||||
pos += 2;
|
||||
pos
|
||||
}
|
||||
|
||||
#[cfg(target_arch = "x86_64")]
|
||||
#[target_feature(enable = "avx2,fma")]
|
||||
fn ifft_bfs_16_avx2_fma(m: usize, re: &mut [f64], im: &mut [f64], omg: &[f64], mut pos: usize) -> usize {
|
||||
let log_m: usize = (usize::BITS - (m - 1).leading_zeros()) as usize;
|
||||
|
||||
for off in (0..m).step_by(16) {
|
||||
ifft16_avx2_fma(
|
||||
as_arr_mut::<16, f64>(&mut re[off..]),
|
||||
as_arr_mut::<16, f64>(&mut im[off..]),
|
||||
as_arr::<16, f64>(&omg[pos..]),
|
||||
);
|
||||
pos += 16;
|
||||
}
|
||||
|
||||
let mut h: usize = 16;
|
||||
let m_half: usize = m >> 1;
|
||||
|
||||
while h < m_half {
|
||||
let mm: usize = h << 2;
|
||||
for off in (0..m).step_by(mm) {
|
||||
inv_bitwiddle_ifft_avx2_fma(
|
||||
h,
|
||||
&mut re[off..],
|
||||
&mut im[off..],
|
||||
as_arr::<4, f64>(&omg[pos..]),
|
||||
);
|
||||
pos += 4;
|
||||
}
|
||||
h = mm;
|
||||
}
|
||||
|
||||
if !log_m.is_multiple_of(2) {
|
||||
inv_twiddle_ifft_avx2_fma(h, re, im, *as_arr::<2, f64>(&omg[pos..]));
|
||||
pos += 2;
|
||||
}
|
||||
|
||||
pos
|
||||
}
|
||||
|
||||
#[cfg(any(target_arch = "x86_64", target_arch = "x86"))]
|
||||
#[target_feature(enable = "avx2,fma")]
|
||||
fn inv_twiddle_ifft_avx2_fma(h: usize, re: &mut [f64], im: &mut [f64], omg: [f64; 2]) {
|
||||
unsafe {
|
||||
let omx: __m128d = _mm_load_pd(omg.as_ptr());
|
||||
let omra: __m256d = _mm256_set_m128d(omx, omx);
|
||||
let omi: __m256d = _mm256_unpackhi_pd(omra, omra);
|
||||
let omr: __m256d = _mm256_unpacklo_pd(omra, omra);
|
||||
let mut r0: *mut f64 = re.as_mut_ptr();
|
||||
let mut r1: *mut f64 = re.as_mut_ptr().add(h);
|
||||
let mut i0: *mut f64 = im.as_mut_ptr();
|
||||
let mut i1: *mut f64 = im.as_mut_ptr().add(h);
|
||||
for _ in (0..h).step_by(4) {
|
||||
let mut ur0: __m256d = _mm256_loadu_pd(r0);
|
||||
let mut ur1: __m256d = _mm256_loadu_pd(r1);
|
||||
let mut ui0: __m256d = _mm256_loadu_pd(i0);
|
||||
let mut ui1: __m256d = _mm256_loadu_pd(i1);
|
||||
let tra = _mm256_sub_pd(ur0, ur1);
|
||||
let tia = _mm256_sub_pd(ui0, ui1);
|
||||
ur0 = _mm256_add_pd(ur0, ur1);
|
||||
ui0 = _mm256_add_pd(ui0, ui1);
|
||||
ur1 = _mm256_mul_pd(omi, tia);
|
||||
ui1 = _mm256_mul_pd(omi, tra);
|
||||
ur1 = _mm256_fmsub_pd(omr, tra, ur1);
|
||||
ui1 = _mm256_fmadd_pd(omr, tia, ui1);
|
||||
_mm256_storeu_pd(r0, ur0);
|
||||
_mm256_storeu_pd(r1, ur1);
|
||||
_mm256_storeu_pd(i0, ui0);
|
||||
_mm256_storeu_pd(i1, ui1);
|
||||
|
||||
r0 = r0.add(4);
|
||||
r1 = r1.add(4);
|
||||
i0 = i0.add(4);
|
||||
i1 = i1.add(4);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(any(target_arch = "x86_64", target_arch = "x86"))]
|
||||
#[target_feature(enable = "avx2,fma")]
|
||||
fn inv_bitwiddle_ifft_avx2_fma(h: usize, re: &mut [f64], im: &mut [f64], omg: &[f64; 4]) {
|
||||
unsafe {
|
||||
let mut r0: *mut f64 = re.as_mut_ptr();
|
||||
let mut r1: *mut f64 = re.as_mut_ptr().add(h);
|
||||
let mut r2: *mut f64 = re.as_mut_ptr().add(2 * h);
|
||||
let mut r3: *mut f64 = re.as_mut_ptr().add(3 * h);
|
||||
let mut i0: *mut f64 = im.as_mut_ptr();
|
||||
let mut i1: *mut f64 = im.as_mut_ptr().add(h);
|
||||
let mut i2: *mut f64 = im.as_mut_ptr().add(2 * h);
|
||||
let mut i3: *mut f64 = im.as_mut_ptr().add(3 * h);
|
||||
let om0: __m256d = _mm256_loadu_pd(omg.as_ptr());
|
||||
let omb: __m256d = _mm256_permute2f128_pd(om0, om0, 0x11);
|
||||
let oma: __m256d = _mm256_permute2f128_pd(om0, om0, 0x00);
|
||||
let omai: __m256d = _mm256_unpackhi_pd(oma, oma);
|
||||
let omar: __m256d = _mm256_unpacklo_pd(oma, oma);
|
||||
let ombi: __m256d = _mm256_unpackhi_pd(omb, omb);
|
||||
let ombr: __m256d = _mm256_unpacklo_pd(omb, omb);
|
||||
for _ in (0..h).step_by(4) {
|
||||
let mut ur0: __m256d = _mm256_loadu_pd(r0);
|
||||
let mut ur1: __m256d = _mm256_loadu_pd(r1);
|
||||
let mut ur2: __m256d = _mm256_loadu_pd(r2);
|
||||
let mut ur3: __m256d = _mm256_loadu_pd(r3);
|
||||
let mut ui0: __m256d = _mm256_loadu_pd(i0);
|
||||
let mut ui1: __m256d = _mm256_loadu_pd(i1);
|
||||
let mut ui2: __m256d = _mm256_loadu_pd(i2);
|
||||
let mut ui3: __m256d = _mm256_loadu_pd(i3);
|
||||
|
||||
let mut tra: __m256d = _mm256_sub_pd(ur0, ur1);
|
||||
let mut trb: __m256d = _mm256_sub_pd(ur2, ur3);
|
||||
let mut tia: __m256d = _mm256_sub_pd(ui0, ui1);
|
||||
let mut tib: __m256d = _mm256_sub_pd(ui2, ui3);
|
||||
ur0 = _mm256_add_pd(ur0, ur1);
|
||||
ur2 = _mm256_add_pd(ur2, ur3);
|
||||
ui0 = _mm256_add_pd(ui0, ui1);
|
||||
ui2 = _mm256_add_pd(ui2, ui3);
|
||||
ur1 = _mm256_mul_pd(omai, tia);
|
||||
ur3 = _mm256_mul_pd(omar, tib);
|
||||
ui1 = _mm256_mul_pd(omai, tra);
|
||||
ui3 = _mm256_mul_pd(omar, trb);
|
||||
ur1 = _mm256_fmsub_pd(omar, tra, ur1);
|
||||
ur3 = _mm256_fmadd_pd(omai, trb, ur3);
|
||||
ui1 = _mm256_fmadd_pd(omar, tia, ui1);
|
||||
ui3 = _mm256_fmsub_pd(omai, tib, ui3);
|
||||
|
||||
tra = _mm256_sub_pd(ur0, ur2);
|
||||
trb = _mm256_sub_pd(ur1, ur3);
|
||||
tia = _mm256_sub_pd(ui0, ui2);
|
||||
tib = _mm256_sub_pd(ui1, ui3);
|
||||
ur0 = _mm256_add_pd(ur0, ur2);
|
||||
ur1 = _mm256_add_pd(ur1, ur3);
|
||||
ui0 = _mm256_add_pd(ui0, ui2);
|
||||
ui1 = _mm256_add_pd(ui1, ui3);
|
||||
ur2 = _mm256_mul_pd(ombi, tia);
|
||||
ur3 = _mm256_mul_pd(ombi, tib);
|
||||
ui2 = _mm256_mul_pd(ombi, tra);
|
||||
ui3 = _mm256_mul_pd(ombi, trb);
|
||||
ur2 = _mm256_fmsub_pd(ombr, tra, ur2);
|
||||
ur3 = _mm256_fmsub_pd(ombr, trb, ur3);
|
||||
ui2 = _mm256_fmadd_pd(ombr, tia, ui2);
|
||||
ui3 = _mm256_fmadd_pd(ombr, tib, ui3);
|
||||
|
||||
_mm256_storeu_pd(r0, ur0);
|
||||
_mm256_storeu_pd(r1, ur1);
|
||||
_mm256_storeu_pd(r2, ur2);
|
||||
_mm256_storeu_pd(r3, ur3);
|
||||
_mm256_storeu_pd(i0, ui0);
|
||||
_mm256_storeu_pd(i1, ui1);
|
||||
_mm256_storeu_pd(i2, ui2);
|
||||
_mm256_storeu_pd(i3, ui3);
|
||||
|
||||
r0 = r0.add(4);
|
||||
r1 = r1.add(4);
|
||||
r2 = r2.add(4);
|
||||
r3 = r3.add(4);
|
||||
i0 = i0.add(4);
|
||||
i1 = i1.add(4);
|
||||
i2 = i2.add(4);
|
||||
i3 = i3.add(4);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_ifft_avx2_fma() {
|
||||
use super::*;
|
||||
|
||||
#[cfg(any(target_arch = "x86_64", target_arch = "x86"))]
|
||||
#[target_feature(enable = "avx2,fma")]
|
||||
fn internal(log_m: usize) {
|
||||
use poulpy_hal::reference::fft64::reim::ReimIFFTRef;
|
||||
|
||||
let m: usize = 1 << log_m;
|
||||
|
||||
let table: ReimIFFTTable<f64> = ReimIFFTTable::<f64>::new(m);
|
||||
|
||||
let mut values_0: Vec<f64> = vec![0f64; m << 1];
|
||||
let scale: f64 = 1.0f64 / m as f64;
|
||||
values_0
|
||||
.iter_mut()
|
||||
.enumerate()
|
||||
.for_each(|(i, x)| *x = (i + 1) as f64 * scale);
|
||||
|
||||
let mut values_1: Vec<f64> = vec![0f64; m << 1];
|
||||
values_1
|
||||
.iter_mut()
|
||||
.zip(values_0.iter())
|
||||
.for_each(|(y, x)| *y = *x);
|
||||
|
||||
ReimIFFTAvx::reim_dft_execute(&table, &mut values_0);
|
||||
ReimIFFTRef::reim_dft_execute(&table, &mut values_1);
|
||||
|
||||
let max_diff: f64 = 1.0 / ((1u64 << (53 - log_m - 1)) as f64);
|
||||
|
||||
for i in 0..m * 2 {
|
||||
let diff: f64 = (values_0[i] - values_1[i]).abs();
|
||||
assert!(
|
||||
diff <= max_diff,
|
||||
"{} -> {}-{} = {}",
|
||||
i,
|
||||
values_0[i],
|
||||
values_1[i],
|
||||
diff
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
if std::is_x86_feature_detected!("avx2") {
|
||||
for log_m in 0..16 {
|
||||
unsafe { internal(log_m) }
|
||||
}
|
||||
} else {
|
||||
eprintln!("skipping: CPU lacks avx2");
|
||||
}
|
||||
}
|
||||
poulpy-backend/src/cpu_fft64_avx/reim/mod.rs (new file, 72 lines)
@@ -0,0 +1,72 @@
|
||||
// ----------------------------------------------------------------------
|
||||
// DISCLAIMER
|
||||
//
|
||||
// This module contains code that has been directly ported from the
|
||||
// spqlios-arithmetic library
|
||||
// (https://github.com/tfhe/spqlios-arithmetic), which is licensed
|
||||
// under the Apache License, Version 2.0.
|
||||
//
|
||||
// The porting process from C to Rust was done with minimal changes
|
||||
// in order to preserve the semantics and performance characteristics
|
||||
// of the original implementation.
|
||||
//
|
||||
// Both Poulpy and spqlios-arithmetic are distributed under the terms
|
||||
// of the Apache License, Version 2.0. See the LICENSE file for details.
|
||||
//
|
||||
// ----------------------------------------------------------------------
|
||||
|
||||
#![allow(bad_asm_style)]
|
||||
|
||||
mod conversion;
|
||||
mod fft_avx2_fma;
|
||||
mod fft_vec_avx2_fma;
|
||||
mod ifft_avx2_fma;
|
||||
|
||||
use std::arch::global_asm;
|
||||
|
||||
pub(crate) use conversion::*;
|
||||
pub(crate) use fft_vec_avx2_fma::*;
|
||||
|
||||
use poulpy_hal::reference::fft64::reim::{ReimDFTExecute, ReimFFTTable, ReimIFFTTable};
|
||||
use rand_distr::num_traits::{Float, FloatConst};
|
||||
|
||||
use crate::cpu_fft64_avx::reim::{fft_avx2_fma::fft_avx2_fma, ifft_avx2_fma::ifft_avx2_fma};
|
||||
|
||||
global_asm!(
|
||||
include_str!("fft16_avx2_fma.s"),
|
||||
include_str!("ifft16_avx2_fma.s")
|
||||
);
|
||||
|
||||
#[inline(always)]
|
||||
pub(crate) fn as_arr<const SIZE: usize, R: Float + FloatConst>(x: &[R]) -> &[R; SIZE] {
|
||||
debug_assert!(x.len() >= SIZE);
|
||||
unsafe { &*(x.as_ptr() as *const [R; SIZE]) }
|
||||
}
|
||||
|
||||
#[inline(always)]
|
||||
pub(crate) fn as_arr_mut<const SIZE: usize, R: Float + FloatConst>(x: &mut [R]) -> &mut [R; SIZE] {
|
||||
debug_assert!(x.len() >= SIZE);
|
||||
unsafe { &mut *(x.as_mut_ptr() as *mut [R; SIZE]) }
|
||||
}
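A safe alternative sketch for these casts (an assumption, not what the crate does): `TryFrom<&[R]>` for `&[R; SIZE]` performs the same reinterpretation with a checked length, at the cost of a branch in release builds.
// fn as_arr<const SIZE: usize, R>(x: &[R]) -> &[R; SIZE] {
//     (&x[..SIZE]).try_into().expect("slice shorter than SIZE")
// }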
|
||||
|
||||
pub struct ReimFFTAvx;
|
||||
|
||||
impl ReimDFTExecute<ReimFFTTable<f64>, f64> for ReimFFTAvx {
|
||||
#[inline(always)]
|
||||
fn reim_dft_execute(table: &ReimFFTTable<f64>, data: &mut [f64]) {
|
||||
unsafe {
|
||||
fft_avx2_fma(table.m(), table.omg(), data);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub struct ReimIFFTAvx;
|
||||
|
||||
impl ReimDFTExecute<ReimIFFTTable<f64>, f64> for ReimIFFTAvx {
|
||||
#[inline(always)]
|
||||
fn reim_dft_execute(table: &ReimIFFTTable<f64>, data: &mut [f64]) {
|
||||
unsafe {
|
||||
ifft_avx2_fma(table.m(), table.omg(), data);
|
||||
}
|
||||
}
|
||||
}
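A minimal usage sketch of the safe wrappers, mirroring the tests (the table size and the split re/im data layout are the only assumptions):
// let m: usize = 1 << 10;
// let table: ReimFFTTable<f64> = ReimFFTTable::<f64>::new(m);
// let mut data: Vec<f64> = vec![0f64; 2 * m]; // first half real parts, second half imaginary parts
// ReimFFTAvx::reim_dft_execute(&table, &mut data);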
Block a user