Backend refactor (#120)

* remove spqlios, split cpu_ref and cpu_avx into different crates

* remove spqlios submodule

* update crate naming & add avx tests
Author: Jean-Philippe Bossuat
Date: 2025-11-19 15:34:31 +01:00 (committed by GitHub)
Parent: 84598e42fe
Commit: 9e007c988f
182 changed files with 1053 additions and 4483 deletions

poulpy-cpu-avx/src/lib.rs (new file, 16 lines)

@@ -0,0 +1,16 @@
mod module;
mod reim;
mod reim4;
mod scratch;
mod svp;
mod vec_znx;
mod vec_znx_big;
mod vec_znx_dft;
mod vmp;
mod znx_avx;
pub struct FFT64Avx {}
pub use reim::*;
#[cfg(test)]
pub mod tests;


@@ -0,0 +1,525 @@
use std::ptr::NonNull;
use poulpy_hal::{
layouts::{Backend, Module},
oep::ModuleNewImpl,
reference::{
fft64::{
reim::{
ReimAdd, ReimAddInplace, ReimAddMul, ReimCopy, ReimDFTExecute, ReimFFTTable, ReimFromZnx, ReimIFFTTable, ReimMul,
ReimMulInplace, ReimNegate, ReimNegateInplace, ReimSub, ReimSubInplace, ReimSubNegateInplace, ReimToZnx,
ReimToZnxInplace, ReimZero, reim_copy_ref, reim_zero_ref,
},
reim4::{
Reim4Extract1Blk, Reim4Mat1ColProd, Reim4Mat2Cols2ndColProd, Reim4Mat2ColsProd, Reim4Save1Blk, Reim4Save2Blks,
},
},
znx::{
ZnxAdd, ZnxAddInplace, ZnxAutomorphism, ZnxCopy, ZnxExtractDigitAddMul, ZnxMulAddPowerOfTwo, ZnxMulPowerOfTwo,
ZnxMulPowerOfTwoInplace, ZnxNegate, ZnxNegateInplace, ZnxNormalizeDigit, ZnxNormalizeFinalStep,
ZnxNormalizeFinalStepInplace, ZnxNormalizeFirstStep, ZnxNormalizeFirstStepCarryOnly, ZnxNormalizeFirstStepInplace,
ZnxNormalizeMiddleStep, ZnxNormalizeMiddleStepCarryOnly, ZnxNormalizeMiddleStepInplace, ZnxRotate, ZnxSub,
ZnxSubInplace, ZnxSubNegateInplace, ZnxSwitchRing, ZnxZero, znx_copy_ref, znx_rotate, znx_zero_ref,
},
},
};
use crate::{
FFT64Avx,
reim::{
ReimFFTAvx, ReimIFFTAvx, reim_add_avx2_fma, reim_add_inplace_avx2_fma, reim_addmul_avx2_fma, reim_from_znx_i64_bnd50_fma,
reim_mul_avx2_fma, reim_mul_inplace_avx2_fma, reim_negate_avx2_fma, reim_negate_inplace_avx2_fma, reim_sub_avx2_fma,
reim_sub_inplace_avx2_fma, reim_sub_negate_inplace_avx2_fma, reim_to_znx_i64_inplace_bnd63_avx2_fma,
},
reim_to_znx_i64_bnd63_avx2_fma,
reim4::{
reim4_extract_1blk_from_reim_avx, reim4_save_1blk_to_reim_avx, reim4_save_2blk_to_reim_avx,
reim4_vec_mat1col_product_avx, reim4_vec_mat2cols_2ndcol_product_avx, reim4_vec_mat2cols_product_avx,
},
znx_avx::{
znx_add_avx, znx_add_inplace_avx, znx_automorphism_avx, znx_extract_digit_addmul_avx, znx_mul_add_power_of_two_avx,
znx_mul_power_of_two_avx, znx_mul_power_of_two_inplace_avx, znx_negate_avx, znx_negate_inplace_avx,
znx_normalize_digit_avx, znx_normalize_final_step_avx, znx_normalize_final_step_inplace_avx,
znx_normalize_first_step_avx, znx_normalize_first_step_carry_only_avx, znx_normalize_first_step_inplace_avx,
znx_normalize_middle_step_avx, znx_normalize_middle_step_carry_only_avx, znx_normalize_middle_step_inplace_avx,
znx_sub_avx, znx_sub_inplace_avx, znx_sub_negate_inplace_avx, znx_switch_ring_avx,
},
};
#[repr(C)]
pub struct FFT64AvxHandle {
table_fft: ReimFFTTable<f64>,
table_ifft: ReimIFFTTable<f64>,
}
impl Backend for FFT64Avx {
type ScalarPrep = f64;
type ScalarBig = i64;
type Handle = FFT64AvxHandle;
unsafe fn destroy(handle: NonNull<Self::Handle>) {
unsafe {
drop(Box::from_raw(handle.as_ptr()));
}
}
fn layout_big_word_count() -> usize {
1
}
fn layout_prep_word_count() -> usize {
1
}
}
unsafe impl ModuleNewImpl<Self> for FFT64Avx {
fn new_impl(n: u64) -> Module<Self> {
if !std::arch::is_x86_feature_detected!("avx")
|| !std::arch::is_x86_feature_detected!("avx2")
|| !std::arch::is_x86_feature_detected!("fma")
{
panic!("arch must support avx2, avx and fma")
}
let handle: FFT64AvxHandle = FFT64AvxHandle {
table_fft: ReimFFTTable::new(n as usize >> 1),
table_ifft: ReimIFFTTable::new(n as usize >> 1),
};
// Leak Box to get a stable NonNull pointer
let ptr: NonNull<FFT64AvxHandle> = NonNull::from(Box::leak(Box::new(handle)));
unsafe { Module::from_nonnull(ptr, n) }
}
}
pub trait FFT64ModuleHandle {
fn get_fft_table(&self) -> &ReimFFTTable<f64>;
fn get_ifft_table(&self) -> &ReimIFFTTable<f64>;
}
impl FFT64ModuleHandle for Module<FFT64Avx> {
fn get_fft_table(&self) -> &ReimFFTTable<f64> {
let h: &FFT64AvxHandle = unsafe { &*self.ptr() };
&h.table_fft
}
fn get_ifft_table(&self) -> &ReimIFFTTable<f64> {
let h: &FFT64AvxHandle = unsafe { &*self.ptr() };
&h.table_ifft
}
}
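
The constructor above leaks a boxed `FFT64AvxHandle` so that `Module` can carry a stable `NonNull` pointer, and `FFT64ModuleHandle` is how that pointer is read back. A minimal sketch of how the two fit together, written as if it lived in this crate's tests (the ring degree `1 << 12` and the `crate::module` path are illustrative assumptions, not part of this diff):

```rust
#[cfg(test)]
mod module_handle_sketch {
    use poulpy_hal::{layouts::Module, oep::ModuleNewImpl};
    // Assumed path: this file is the `module` module declared in lib.rs.
    use crate::{FFT64Avx, module::FFT64ModuleHandle};

    #[test]
    fn build_and_read_tables() {
        // new_impl panics if the CPU lacks avx/avx2/fma, so gate on detection first.
        if !std::arch::is_x86_feature_detected!("avx2") || !std::arch::is_x86_feature_detected!("fma") {
            return;
        }
        let n: u64 = 1 << 12; // illustrative ring degree
        let module: Module<FFT64Avx> = FFT64Avx::new_impl(n);
        // The leaked handle holds FFT/IFFT tables over m = n/2 complex coefficients.
        assert_eq!(module.get_fft_table().m(), (n as usize) >> 1);
        assert_eq!(module.get_ifft_table().m(), (n as usize) >> 1);
    }
}
```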
impl ZnxAdd for FFT64Avx {
#[inline(always)]
fn znx_add(res: &mut [i64], a: &[i64], b: &[i64]) {
unsafe {
znx_add_avx(res, a, b);
}
}
}
impl ZnxAddInplace for FFT64Avx {
#[inline(always)]
fn znx_add_inplace(res: &mut [i64], a: &[i64]) {
unsafe {
znx_add_inplace_avx(res, a);
}
}
}
impl ZnxSub for FFT64Avx {
#[inline(always)]
fn znx_sub(res: &mut [i64], a: &[i64], b: &[i64]) {
unsafe {
znx_sub_avx(res, a, b);
}
}
}
impl ZnxSubInplace for FFT64Avx {
#[inline(always)]
fn znx_sub_inplace(res: &mut [i64], a: &[i64]) {
unsafe {
znx_sub_inplace_avx(res, a);
}
}
}
impl ZnxSubNegateInplace for FFT64Avx {
#[inline(always)]
fn znx_sub_negate_inplace(res: &mut [i64], a: &[i64]) {
unsafe {
znx_sub_negate_inplace_avx(res, a);
}
}
}
impl ZnxAutomorphism for FFT64Avx {
#[inline(always)]
fn znx_automorphism(p: i64, res: &mut [i64], a: &[i64]) {
unsafe {
znx_automorphism_avx(p, res, a);
}
}
}
impl ZnxCopy for FFT64Avx {
#[inline(always)]
fn znx_copy(res: &mut [i64], a: &[i64]) {
znx_copy_ref(res, a);
}
}
impl ZnxNegate for FFT64Avx {
#[inline(always)]
fn znx_negate(res: &mut [i64], src: &[i64]) {
unsafe {
znx_negate_avx(res, src);
}
}
}
impl ZnxNegateInplace for FFT64Avx {
#[inline(always)]
fn znx_negate_inplace(res: &mut [i64]) {
unsafe {
znx_negate_inplace_avx(res);
}
}
}
impl ZnxMulAddPowerOfTwo for FFT64Avx {
#[inline(always)]
fn znx_muladd_power_of_two(k: i64, res: &mut [i64], a: &[i64]) {
unsafe {
znx_mul_add_power_of_two_avx(k, res, a);
}
}
}
impl ZnxMulPowerOfTwo for FFT64Avx {
#[inline(always)]
fn znx_mul_power_of_two(k: i64, res: &mut [i64], a: &[i64]) {
unsafe {
znx_mul_power_of_two_avx(k, res, a);
}
}
}
impl ZnxMulPowerOfTwoInplace for FFT64Avx {
#[inline(always)]
fn znx_mul_power_of_two_inplace(k: i64, res: &mut [i64]) {
unsafe {
znx_mul_power_of_two_inplace_avx(k, res);
}
}
}
impl ZnxRotate for FFT64Avx {
#[inline(always)]
fn znx_rotate(p: i64, res: &mut [i64], src: &[i64]) {
znx_rotate::<Self>(p, res, src);
}
}
impl ZnxZero for FFT64Avx {
#[inline(always)]
fn znx_zero(res: &mut [i64]) {
znx_zero_ref(res);
}
}
impl ZnxSwitchRing for FFT64Avx {
#[inline(always)]
fn znx_switch_ring(res: &mut [i64], a: &[i64]) {
unsafe {
znx_switch_ring_avx(res, a);
}
}
}
impl ZnxNormalizeFinalStep for FFT64Avx {
#[inline(always)]
fn znx_normalize_final_step(base2k: usize, lsh: usize, x: &mut [i64], a: &[i64], carry: &mut [i64]) {
unsafe {
znx_normalize_final_step_avx(base2k, lsh, x, a, carry);
}
}
}
impl ZnxNormalizeFinalStepInplace for FFT64Avx {
#[inline(always)]
fn znx_normalize_final_step_inplace(base2k: usize, lsh: usize, x: &mut [i64], carry: &mut [i64]) {
unsafe {
znx_normalize_final_step_inplace_avx(base2k, lsh, x, carry);
}
}
}
impl ZnxNormalizeFirstStep for FFT64Avx {
#[inline(always)]
fn znx_normalize_first_step(base2k: usize, lsh: usize, x: &mut [i64], a: &[i64], carry: &mut [i64]) {
unsafe {
znx_normalize_first_step_avx(base2k, lsh, x, a, carry);
}
}
}
impl ZnxNormalizeFirstStepCarryOnly for FFT64Avx {
#[inline(always)]
fn znx_normalize_first_step_carry_only(base2k: usize, lsh: usize, x: &[i64], carry: &mut [i64]) {
unsafe {
znx_normalize_first_step_carry_only_avx(base2k, lsh, x, carry);
}
}
}
impl ZnxNormalizeFirstStepInplace for FFT64Avx {
#[inline(always)]
fn znx_normalize_first_step_inplace(base2k: usize, lsh: usize, x: &mut [i64], carry: &mut [i64]) {
unsafe {
znx_normalize_first_step_inplace_avx(base2k, lsh, x, carry);
}
}
}
impl ZnxNormalizeMiddleStep for FFT64Avx {
#[inline(always)]
fn znx_normalize_middle_step(base2k: usize, lsh: usize, x: &mut [i64], a: &[i64], carry: &mut [i64]) {
unsafe {
znx_normalize_middle_step_avx(base2k, lsh, x, a, carry);
}
}
}
impl ZnxNormalizeMiddleStepCarryOnly for FFT64Avx {
#[inline(always)]
fn znx_normalize_middle_step_carry_only(base2k: usize, lsh: usize, x: &[i64], carry: &mut [i64]) {
unsafe {
znx_normalize_middle_step_carry_only_avx(base2k, lsh, x, carry);
}
}
}
impl ZnxNormalizeMiddleStepInplace for FFT64Avx {
#[inline(always)]
fn znx_normalize_middle_step_inplace(base2k: usize, lsh: usize, x: &mut [i64], carry: &mut [i64]) {
unsafe {
znx_normalize_middle_step_inplace_avx(base2k, lsh, x, carry);
}
}
}
impl ZnxExtractDigitAddMul for FFT64Avx {
#[inline(always)]
fn znx_extract_digit_addmul(base2k: usize, lsh: usize, res: &mut [i64], src: &mut [i64]) {
unsafe {
znx_extract_digit_addmul_avx(base2k, lsh, res, src);
}
}
}
impl ZnxNormalizeDigit for FFT64Avx {
#[inline(always)]
fn znx_normalize_digit(base2k: usize, res: &mut [i64], src: &mut [i64]) {
unsafe {
znx_normalize_digit_avx(base2k, res, src);
}
}
}
impl ReimDFTExecute<ReimFFTTable<f64>, f64> for FFT64Avx {
#[inline(always)]
fn reim_dft_execute(table: &ReimFFTTable<f64>, data: &mut [f64]) {
ReimFFTAvx::reim_dft_execute(table, data);
}
}
impl ReimDFTExecute<ReimIFFTTable<f64>, f64> for FFT64Avx {
#[inline(always)]
fn reim_dft_execute(table: &ReimIFFTTable<f64>, data: &mut [f64]) {
ReimIFFTAvx::reim_dft_execute(table, data);
}
}
impl ReimFromZnx for FFT64Avx {
#[inline(always)]
fn reim_from_znx(res: &mut [f64], a: &[i64]) {
unsafe {
reim_from_znx_i64_bnd50_fma(res, a);
}
}
}
impl ReimToZnx for FFT64Avx {
#[inline(always)]
fn reim_to_znx(res: &mut [i64], divisor: f64, a: &[f64]) {
unsafe {
reim_to_znx_i64_bnd63_avx2_fma(res, divisor, a);
}
}
}
impl ReimToZnxInplace for FFT64Avx {
#[inline(always)]
fn reim_to_znx_inplace(res: &mut [f64], divisor: f64) {
unsafe {
reim_to_znx_i64_inplace_bnd63_avx2_fma(res, divisor);
}
}
}
impl ReimAdd for FFT64Avx {
#[inline(always)]
fn reim_add(res: &mut [f64], a: &[f64], b: &[f64]) {
unsafe {
reim_add_avx2_fma(res, a, b);
}
}
}
impl ReimAddInplace for FFT64Avx {
#[inline(always)]
fn reim_add_inplace(res: &mut [f64], a: &[f64]) {
unsafe {
reim_add_inplace_avx2_fma(res, a);
}
}
}
impl ReimSub for FFT64Avx {
#[inline(always)]
fn reim_sub(res: &mut [f64], a: &[f64], b: &[f64]) {
unsafe {
reim_sub_avx2_fma(res, a, b);
}
}
}
impl ReimSubInplace for FFT64Avx {
#[inline(always)]
fn reim_sub_inplace(res: &mut [f64], a: &[f64]) {
unsafe {
reim_sub_inplace_avx2_fma(res, a);
}
}
}
impl ReimSubNegateInplace for FFT64Avx {
#[inline(always)]
fn reim_sub_negate_inplace(res: &mut [f64], a: &[f64]) {
unsafe {
reim_sub_negate_inplace_avx2_fma(res, a);
}
}
}
impl ReimNegate for FFT64Avx {
#[inline(always)]
fn reim_negate(res: &mut [f64], a: &[f64]) {
unsafe {
reim_negate_avx2_fma(res, a);
}
}
}
impl ReimNegateInplace for FFT64Avx {
#[inline(always)]
fn reim_negate_inplace(res: &mut [f64]) {
unsafe {
reim_negate_inplace_avx2_fma(res);
}
}
}
impl ReimMul for FFT64Avx {
#[inline(always)]
fn reim_mul(res: &mut [f64], a: &[f64], b: &[f64]) {
unsafe {
reim_mul_avx2_fma(res, a, b);
}
}
}
impl ReimMulInplace for FFT64Avx {
#[inline(always)]
fn reim_mul_inplace(res: &mut [f64], a: &[f64]) {
unsafe {
reim_mul_inplace_avx2_fma(res, a);
}
}
}
impl ReimAddMul for FFT64Avx {
#[inline(always)]
fn reim_addmul(res: &mut [f64], a: &[f64], b: &[f64]) {
unsafe {
reim_addmul_avx2_fma(res, a, b);
}
}
}
impl ReimCopy for FFT64Avx {
#[inline(always)]
fn reim_copy(res: &mut [f64], a: &[f64]) {
reim_copy_ref(res, a);
}
}
impl ReimZero for FFT64Avx {
#[inline(always)]
fn reim_zero(res: &mut [f64]) {
reim_zero_ref(res);
}
}
impl Reim4Extract1Blk for FFT64Avx {
#[inline(always)]
fn reim4_extract_1blk(m: usize, rows: usize, blk: usize, dst: &mut [f64], src: &[f64]) {
unsafe {
reim4_extract_1blk_from_reim_avx(m, rows, blk, dst, src);
}
}
}
impl Reim4Save1Blk for FFT64Avx {
#[inline(always)]
fn reim4_save_1blk<const OVERWRITE: bool>(m: usize, blk: usize, dst: &mut [f64], src: &[f64]) {
unsafe {
reim4_save_1blk_to_reim_avx::<OVERWRITE>(m, blk, dst, src);
}
}
}
impl Reim4Save2Blks for FFT64Avx {
#[inline(always)]
fn reim4_save_2blks<const OVERWRITE: bool>(m: usize, blk: usize, dst: &mut [f64], src: &[f64]) {
unsafe {
reim4_save_2blk_to_reim_avx::<OVERWRITE>(m, blk, dst, src);
}
}
}
impl Reim4Mat1ColProd for FFT64Avx {
#[inline(always)]
fn reim4_mat1col_prod(nrows: usize, dst: &mut [f64], u: &[f64], v: &[f64]) {
unsafe {
reim4_vec_mat1col_product_avx(nrows, dst, u, v);
}
}
}
impl Reim4Mat2ColsProd for FFT64Avx {
#[inline(always)]
fn reim4_mat2cols_prod(nrows: usize, dst: &mut [f64], u: &[f64], v: &[f64]) {
unsafe {
reim4_vec_mat2cols_product_avx(nrows, dst, u, v);
}
}
}
impl Reim4Mat2Cols2ndColProd for FFT64Avx {
#[inline(always)]
fn reim4_mat2cols_2ndcol_prod(nrows: usize, dst: &mut [f64], u: &[f64], v: &[f64]) {
unsafe {
reim4_vec_mat2cols_2ndcol_product_avx(nrows, dst, u, v);
}
}
}


@@ -0,0 +1,266 @@
/// # Correctness
/// Ensured for inputs with absolute value bounded by 2^50-1
/// # Safety
/// Caller must ensure the CPU supports AVX2 and FMA (e.g., via `is_x86_feature_detected!("avx2")` and `is_x86_feature_detected!("fma")`).
#[target_feature(enable = "avx2,fma")]
pub fn reim_from_znx_i64_bnd50_fma(res: &mut [f64], a: &[i64]) {
#[cfg(debug_assertions)]
{
assert_eq!(res.len(), a.len())
}
let n: usize = res.len();
unsafe {
use std::arch::x86_64::{
__m256d, __m256i, _mm256_add_epi64, _mm256_castsi256_pd, _mm256_loadu_si256, _mm256_or_pd, _mm256_set1_epi64x,
_mm256_set1_pd, _mm256_storeu_pd, _mm256_sub_pd,
};
let expo: f64 = (1i64 << 52) as f64;
let add_cst: i64 = 1i64 << 51;
let sub_cst: f64 = (3i64 << 51) as f64;
let expo_256: __m256d = _mm256_set1_pd(expo);
let add_cst_256: __m256i = _mm256_set1_epi64x(add_cst);
let sub_cst_256: __m256d = _mm256_set1_pd(sub_cst);
let mut res_ptr: *mut f64 = res.as_mut_ptr();
let mut a_ptr: *const __m256i = a.as_ptr() as *const __m256i;
let span: usize = n >> 2;
for _ in 0..span {
let mut ai64_256: __m256i = _mm256_loadu_si256(a_ptr);
ai64_256 = _mm256_add_epi64(ai64_256, add_cst_256);
let mut af64_256: __m256d = _mm256_castsi256_pd(ai64_256);
af64_256 = _mm256_or_pd(af64_256, expo_256);
af64_256 = _mm256_sub_pd(af64_256, sub_cst_256);
_mm256_storeu_pd(res_ptr, af64_256);
res_ptr = res_ptr.add(4);
a_ptr = a_ptr.add(1);
}
if !res.len().is_multiple_of(4) {
use poulpy_hal::reference::fft64::reim::reim_from_znx_i64_ref;
reim_from_znx_i64_ref(&mut res[span << 2..], &a[span << 2..])
}
}
}
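
The kernel above is the classic exponent-patching conversion from `i64` to `f64`. A sketch of the identity it relies on, under the stated bound, reading the constants `add_cst = 2^51`, `expo = 2^52`, `sub_cst = 3·2^51` (this note is not part of the diff):

```latex
% For |a| <= 2^50 - 1, the integer a + 2^51 is non-negative and fits in 52 bits,
% so OR-ing its bits into the zero mantissa of the double 2^52 yields exactly
% 2^52 + (a + 2^51), and the final subtraction of 3*2^51 = 2^52 + 2^51 is exact:
\operatorname{bits}(a + 2^{51}) \;\lor\; \operatorname{bits}(2^{52})
  \;=\; \operatorname{bits}\!\big(2^{52} + a + 2^{51}\big),
\qquad
\big(2^{52} + a + 2^{51}\big) - 3\cdot 2^{51} \;=\; a .
```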
/// # Correctness
/// Only ensured for inputs with absolute value bounded by 2^63-1
/// # Safety
/// Caller must ensure the CPU supports AVX2 and FMA (e.g., via `is_x86_feature_detected!("avx2")` and `is_x86_feature_detected!("fma")`).
#[allow(dead_code)]
#[target_feature(enable = "avx2,fma")]
pub fn reim_to_znx_i64_bnd63_avx2_fma(res: &mut [i64], divisor: f64, a: &[f64]) {
#[cfg(debug_assertions)]
{
assert_eq!(res.len(), a.len())
}
let sign_mask: u64 = 0x8000000000000000u64;
let expo_mask: u64 = 0x7FF0000000000000u64;
let mantissa_mask: u64 = (i64::MAX as u64) ^ expo_mask;
let mantissa_msb: u64 = 0x0010000000000000u64;
let divi_bits: f64 = divisor * (1i64 << 52) as f64;
let offset: f64 = divisor / 2.;
unsafe {
use std::arch::x86_64::{
__m256d, __m256i, _mm256_add_pd, _mm256_and_pd, _mm256_and_si256, _mm256_castpd_si256, _mm256_castsi256_pd,
_mm256_loadu_pd, _mm256_or_pd, _mm256_or_si256, _mm256_set1_epi64x, _mm256_set1_pd, _mm256_sllv_epi64,
_mm256_srli_epi64, _mm256_srlv_epi64, _mm256_sub_epi64, _mm256_xor_si256,
};
let sign_mask_256: __m256d = _mm256_castsi256_pd(_mm256_set1_epi64x(sign_mask as i64));
let expo_mask_256: __m256i = _mm256_set1_epi64x(expo_mask as i64);
let mantissa_mask_256: __m256i = _mm256_set1_epi64x(mantissa_mask as i64);
let mantissa_msb_256: __m256i = _mm256_set1_epi64x(mantissa_msb as i64);
let offset_256 = _mm256_set1_pd(offset);
let divi_bits_256 = _mm256_castpd_si256(_mm256_set1_pd(divi_bits));
let mut res_ptr: *mut __m256i = res.as_mut_ptr() as *mut __m256i;
let mut a_ptr: *const f64 = a.as_ptr();
let span: usize = res.len() >> 2;
for _ in 0..span {
// read the next value
use std::arch::x86_64::_mm256_storeu_si256;
let mut a: __m256d = _mm256_loadu_pd(a_ptr);
// a += sign(a) * m/2
let asign: __m256d = _mm256_and_pd(a, sign_mask_256);
a = _mm256_add_pd(a, _mm256_or_pd(asign, offset_256));
// sign: either 0 or -1
let mut sign_mask: __m256i = _mm256_castpd_si256(asign);
sign_mask = _mm256_sub_epi64(_mm256_set1_epi64x(0), _mm256_srli_epi64(sign_mask, 63));
// compute the exponents
let a0exp: __m256i = _mm256_and_si256(_mm256_castpd_si256(a), expo_mask_256);
let mut a0lsh: __m256i = _mm256_sub_epi64(a0exp, divi_bits_256);
let mut a0rsh: __m256i = _mm256_sub_epi64(divi_bits_256, a0exp);
a0lsh = _mm256_srli_epi64(a0lsh, 52);
a0rsh = _mm256_srli_epi64(a0rsh, 52);
// compute the new mantissa
let mut a0pos: __m256i = _mm256_and_si256(_mm256_castpd_si256(a), mantissa_mask_256);
a0pos = _mm256_or_si256(a0pos, mantissa_msb_256);
a0lsh = _mm256_sllv_epi64(a0pos, a0lsh);
a0rsh = _mm256_srlv_epi64(a0pos, a0rsh);
let mut out: __m256i = _mm256_or_si256(a0lsh, a0rsh);
// negate if the sign was negative
out = _mm256_xor_si256(out, sign_mask);
out = _mm256_sub_epi64(out, sign_mask);
// stores
_mm256_storeu_si256(res_ptr, out);
res_ptr = res_ptr.add(1);
a_ptr = a_ptr.add(4);
}
if !res.len().is_multiple_of(4) {
use poulpy_hal::reference::fft64::reim::reim_to_znx_i64_ref;
reim_to_znx_i64_ref(&mut res[span << 2..], divisor, &a[span << 2..])
}
}
}
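
A plausible reading of the kernel above (assuming, as this note does, that `divisor` is a positive power of two, as in the FFT use case): the biased add of `±divisor/2` followed by the exponent/mantissa shifts computes a round-half-away-from-zero quotient,

```latex
\mathrm{res}[i] \;=\; \operatorname{sign}(a[i]) \cdot
  \left\lfloor \frac{\,|a[i]| + \mathrm{divisor}/2\,}{\mathrm{divisor}} \right\rfloor ,
\qquad
\left|\,a[i]/\mathrm{divisor}\,\right| \le 2^{63}-1 .
```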
/// # Correctness
/// Only ensured for inputs with absolute value bounded by 2^63-1
/// # Safety
/// Caller must ensure the CPU supports AVX2 and FMA (e.g., via `is_x86_feature_detected!("avx2")` and `is_x86_feature_detected!("fma")`).
#[target_feature(enable = "avx2,fma")]
pub fn reim_to_znx_i64_inplace_bnd63_avx2_fma(res: &mut [f64], divisor: f64) {
let sign_mask: u64 = 0x8000000000000000u64;
let expo_mask: u64 = 0x7FF0000000000000u64;
let mantissa_mask: u64 = (i64::MAX as u64) ^ expo_mask;
let mantissa_msb: u64 = 0x0010000000000000u64;
let divi_bits: f64 = divisor * (1i64 << 52) as f64;
let offset: f64 = divisor / 2.;
unsafe {
use std::arch::x86_64::{
__m256d, __m256i, _mm256_add_pd, _mm256_and_pd, _mm256_and_si256, _mm256_castpd_si256, _mm256_castsi256_pd,
_mm256_loadu_pd, _mm256_or_pd, _mm256_or_si256, _mm256_set1_epi64x, _mm256_set1_pd, _mm256_sllv_epi64,
_mm256_srli_epi64, _mm256_srlv_epi64, _mm256_sub_epi64, _mm256_xor_si256,
};
use poulpy_hal::reference::fft64::reim::reim_to_znx_i64_inplace_ref;
let sign_mask_256: __m256d = _mm256_castsi256_pd(_mm256_set1_epi64x(sign_mask as i64));
let expo_mask_256: __m256i = _mm256_set1_epi64x(expo_mask as i64);
let mantissa_mask_256: __m256i = _mm256_set1_epi64x(mantissa_mask as i64);
let mantissa_msb_256: __m256i = _mm256_set1_epi64x(mantissa_msb as i64);
let offset_256: __m256d = _mm256_set1_pd(offset);
let divi_bits_256: __m256i = _mm256_castpd_si256(_mm256_set1_pd(divi_bits));
let mut res_ptr_4xi64: *mut __m256i = res.as_mut_ptr() as *mut __m256i;
let mut res_ptr_1xf64: *mut f64 = res.as_mut_ptr();
let span: usize = res.len() >> 2;
for _ in 0..span {
// read the next value
use std::arch::x86_64::_mm256_storeu_si256;
let mut a: __m256d = _mm256_loadu_pd(res_ptr_1xf64);
// a += sign(a) * m/2
let asign: __m256d = _mm256_and_pd(a, sign_mask_256);
a = _mm256_add_pd(a, _mm256_or_pd(asign, offset_256));
// sign: either 0 or -1
let mut sign_mask: __m256i = _mm256_castpd_si256(asign);
sign_mask = _mm256_sub_epi64(_mm256_set1_epi64x(0), _mm256_srli_epi64(sign_mask, 63));
// compute the exponents
let a0exp: __m256i = _mm256_and_si256(_mm256_castpd_si256(a), expo_mask_256);
let mut a0lsh: __m256i = _mm256_sub_epi64(a0exp, divi_bits_256);
let mut a0rsh: __m256i = _mm256_sub_epi64(divi_bits_256, a0exp);
a0lsh = _mm256_srli_epi64(a0lsh, 52);
a0rsh = _mm256_srli_epi64(a0rsh, 52);
// compute the new mantissa
let mut a0pos: __m256i = _mm256_and_si256(_mm256_castpd_si256(a), mantissa_mask_256);
a0pos = _mm256_or_si256(a0pos, mantissa_msb_256);
a0lsh = _mm256_sllv_epi64(a0pos, a0lsh);
a0rsh = _mm256_srlv_epi64(a0pos, a0rsh);
let mut out: __m256i = _mm256_or_si256(a0lsh, a0rsh);
// negate if the sign was negative
out = _mm256_xor_si256(out, sign_mask);
out = _mm256_sub_epi64(out, sign_mask);
// stores
_mm256_storeu_si256(res_ptr_4xi64, out);
res_ptr_4xi64 = res_ptr_4xi64.add(1);
res_ptr_1xf64 = res_ptr_1xf64.add(4);
}
if !res.len().is_multiple_of(4) {
reim_to_znx_i64_inplace_ref(&mut res[span << 2..], divisor)
}
}
}
/// # Correctness
/// Only ensured for inputs with absolute value bounded by 2^50-1
/// # Safety
/// Caller must ensure the CPU supports AVX2 and FMA (e.g., via `is_x86_feature_detected!("avx2")` and `is_x86_feature_detected!("fma")`).
#[target_feature(enable = "avx2,fma")]
#[allow(dead_code)]
pub fn reim_to_znx_i64_avx2_bnd50_fma(res: &mut [i64], divisor: f64, a: &[f64]) {
#[cfg(debug_assertions)]
{
assert_eq!(res.len(), a.len())
}
unsafe {
use std::arch::x86_64::{
__m256d, __m256i, _mm256_add_pd, _mm256_and_si256, _mm256_castpd_si256, _mm256_loadu_pd, _mm256_set1_epi64x,
_mm256_set1_pd, _mm256_storeu_si256, _mm256_sub_epi64,
};
let mantissa_mask: u64 = 0x000FFFFFFFFFFFFFu64;
let sub_cst: i64 = 1i64 << 51;
let add_cst: f64 = divisor * (3i64 << 51) as f64;
let sub_cst_4: __m256i = _mm256_set1_epi64x(sub_cst);
let add_cst_4: std::arch::x86_64::__m256d = _mm256_set1_pd(add_cst);
let mantissa_mask_4: __m256i = _mm256_set1_epi64x(mantissa_mask as i64);
let mut res_ptr: *mut __m256i = res.as_mut_ptr() as *mut __m256i;
let mut a_ptr = a.as_ptr();
let span: usize = res.len() >> 2;
for _ in 0..span {
// read the next value
let mut a: __m256d = _mm256_loadu_pd(a_ptr);
a = _mm256_add_pd(a, add_cst_4);
let mut ai: __m256i = _mm256_castpd_si256(a);
ai = _mm256_and_si256(ai, mantissa_mask_4);
ai = _mm256_sub_epi64(ai, sub_cst_4);
// store the next value
_mm256_storeu_si256(res_ptr, ai);
res_ptr = res_ptr.add(1);
a_ptr = a_ptr.add(4);
}
if !res.len().is_multiple_of(4) {
use poulpy_hal::reference::fft64::reim::reim_to_znx_i64_ref;
reim_to_znx_i64_ref(&mut res[span << 2..], divisor, &a[span << 2..])
}
}
}
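
The `bnd50` variant inverts the exponent-patching trick used in `reim_from_znx_i64_bnd50_fma`. Assuming `divisor` is a power of two and the quotients are bounded by 2^50-1 (assumptions of this note, not stated in the diff), adding `divisor * 3 * 2^51` pins the sum into a binade whose unit in the last place is exactly `divisor`, so the low 52 mantissa bits already hold the rounded quotient offset by 2^51:

```latex
\operatorname{mant}_{52}\!\big(a[i] + \mathrm{divisor}\cdot 3\cdot 2^{51}\big)
  \;=\; \Big\lfloor \tfrac{a[i]}{\mathrm{divisor}} \Big\rceil + 2^{51},
\qquad
\mathrm{res}[i]
  \;=\; \operatorname{mant}_{52}(\cdot) - 2^{51}
  \;=\; \Big\lfloor \tfrac{a[i]}{\mathrm{divisor}} \Big\rceil ,
```

where the rounding is round-to-nearest (ties to even), inherited from the floating-point addition.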


@@ -0,0 +1,162 @@
# ----------------------------------------------------------------------
# This kernel is a direct port of the FFT16 routine from spqlios-arithmetic
# (https://github.com/tfhe/spqlios-arithmetic)
# ----------------------------------------------------------------------
#
.text
.globl fft16_avx2_fma_asm
.hidden fft16_avx2_fma_asm
.p2align 4, 0x90
.type fft16_avx2_fma_asm,@function
fft16_avx2_fma_asm:
.att_syntax prefix
# SysV args: %rdi = re*, %rsi = im*, %rdx = omg*
# stage 0: load inputs
vmovupd (%rdi),%ymm0 # ra0
vmovupd 0x20(%rdi),%ymm1 # ra4
vmovupd 0x40(%rdi),%ymm2 # ra8
vmovupd 0x60(%rdi),%ymm3 # ra12
vmovupd (%rsi),%ymm4 # ia0
vmovupd 0x20(%rsi),%ymm5 # ia4
vmovupd 0x40(%rsi),%ymm6 # ia8
vmovupd 0x60(%rsi),%ymm7 # ia12
# stage 1
vmovupd (%rdx),%xmm12
vinsertf128 $1, %xmm12, %ymm12, %ymm12 # omriri
vshufpd $15, %ymm12, %ymm12, %ymm13 # omai
vshufpd $0, %ymm12, %ymm12, %ymm12 # omar
vmulpd %ymm6,%ymm13,%ymm8
vmulpd %ymm7,%ymm13,%ymm9
vmulpd %ymm2,%ymm13,%ymm10
vmulpd %ymm3,%ymm13,%ymm11
vfmsub231pd %ymm2,%ymm12,%ymm8
vfmsub231pd %ymm3,%ymm12,%ymm9
vfmadd231pd %ymm6,%ymm12,%ymm10
vfmadd231pd %ymm7,%ymm12,%ymm11
vsubpd %ymm8,%ymm0,%ymm2
vsubpd %ymm9,%ymm1,%ymm3
vsubpd %ymm10,%ymm4,%ymm6
vsubpd %ymm11,%ymm5,%ymm7
vaddpd %ymm8,%ymm0,%ymm0
vaddpd %ymm9,%ymm1,%ymm1
vaddpd %ymm10,%ymm4,%ymm4
vaddpd %ymm11,%ymm5,%ymm5
# stage 2
vmovupd 16(%rdx),%xmm12
vinsertf128 $1, %xmm12, %ymm12, %ymm12 # omriri
vshufpd $15, %ymm12, %ymm12, %ymm13 # omai
vshufpd $0, %ymm12, %ymm12, %ymm12 # omar
vmulpd %ymm5,%ymm13,%ymm8
vmulpd %ymm7,%ymm12,%ymm9
vmulpd %ymm1,%ymm13,%ymm10
vmulpd %ymm3,%ymm12,%ymm11
vfmsub231pd %ymm1,%ymm12,%ymm8
vfmadd231pd %ymm3,%ymm13,%ymm9
vfmadd231pd %ymm5,%ymm12,%ymm10
vfmsub231pd %ymm7,%ymm13,%ymm11
vsubpd %ymm8,%ymm0,%ymm1
vaddpd %ymm9,%ymm2,%ymm3
vsubpd %ymm10,%ymm4,%ymm5
vaddpd %ymm11,%ymm6,%ymm7
vaddpd %ymm8,%ymm0,%ymm0
vsubpd %ymm9,%ymm2,%ymm2
vaddpd %ymm10,%ymm4,%ymm4
vsubpd %ymm11,%ymm6,%ymm6
# stage 3
vmovupd 0x20(%rdx),%ymm12
vshufpd $15, %ymm12, %ymm12, %ymm13 # omai
vshufpd $0, %ymm12, %ymm12, %ymm12 # omar
vperm2f128 $0x31,%ymm2,%ymm0,%ymm8
vperm2f128 $0x31,%ymm3,%ymm1,%ymm9
vperm2f128 $0x31,%ymm6,%ymm4,%ymm10
vperm2f128 $0x31,%ymm7,%ymm5,%ymm11
vperm2f128 $0x20,%ymm2,%ymm0,%ymm0
vperm2f128 $0x20,%ymm3,%ymm1,%ymm1
vperm2f128 $0x20,%ymm6,%ymm4,%ymm2
vperm2f128 $0x20,%ymm7,%ymm5,%ymm3
vmulpd %ymm10,%ymm13,%ymm4
vmulpd %ymm11,%ymm12,%ymm5
vmulpd %ymm8,%ymm13,%ymm6
vmulpd %ymm9,%ymm12,%ymm7
vfmsub231pd %ymm8,%ymm12,%ymm4
vfmadd231pd %ymm9,%ymm13,%ymm5
vfmadd231pd %ymm10,%ymm12,%ymm6
vfmsub231pd %ymm11,%ymm13,%ymm7
vsubpd %ymm4,%ymm0,%ymm8
vaddpd %ymm5,%ymm1,%ymm9
vsubpd %ymm6,%ymm2,%ymm10
vaddpd %ymm7,%ymm3,%ymm11
vaddpd %ymm4,%ymm0,%ymm0
vsubpd %ymm5,%ymm1,%ymm1
vaddpd %ymm6,%ymm2,%ymm2
vsubpd %ymm7,%ymm3,%ymm3
# stage 4
vmovupd 0x40(%rdx),%ymm12
vmovupd 0x60(%rdx),%ymm13
vunpckhpd %ymm1,%ymm0,%ymm4
vunpckhpd %ymm3,%ymm2,%ymm6
vunpckhpd %ymm9,%ymm8,%ymm5
vunpckhpd %ymm11,%ymm10,%ymm7
vunpcklpd %ymm1,%ymm0,%ymm0
vunpcklpd %ymm3,%ymm2,%ymm2
vunpcklpd %ymm9,%ymm8,%ymm1
vunpcklpd %ymm11,%ymm10,%ymm3
vmulpd %ymm6,%ymm13,%ymm8
vmulpd %ymm7,%ymm12,%ymm9
vmulpd %ymm4,%ymm13,%ymm10
vmulpd %ymm5,%ymm12,%ymm11
vfmsub231pd %ymm4,%ymm12,%ymm8
vfmadd231pd %ymm5,%ymm13,%ymm9
vfmadd231pd %ymm6,%ymm12,%ymm10
vfmsub231pd %ymm7,%ymm13,%ymm11
vsubpd %ymm8,%ymm0,%ymm4
vaddpd %ymm9,%ymm1,%ymm5
vsubpd %ymm10,%ymm2,%ymm6
vaddpd %ymm11,%ymm3,%ymm7
vaddpd %ymm8,%ymm0,%ymm0
vsubpd %ymm9,%ymm1,%ymm1
vaddpd %ymm10,%ymm2,%ymm2
vsubpd %ymm11,%ymm3,%ymm3
vunpckhpd %ymm7,%ymm3,%ymm11
vunpckhpd %ymm5,%ymm1,%ymm9
vunpcklpd %ymm7,%ymm3,%ymm10
vunpcklpd %ymm5,%ymm1,%ymm8
vunpckhpd %ymm6,%ymm2,%ymm3
vunpckhpd %ymm4,%ymm0,%ymm1
vunpcklpd %ymm6,%ymm2,%ymm2
vunpcklpd %ymm4,%ymm0,%ymm0
vperm2f128 $0x31,%ymm10,%ymm2,%ymm6
vperm2f128 $0x31,%ymm11,%ymm3,%ymm7
vperm2f128 $0x20,%ymm10,%ymm2,%ymm4
vperm2f128 $0x20,%ymm11,%ymm3,%ymm5
vperm2f128 $0x31,%ymm8,%ymm0,%ymm2
vperm2f128 $0x31,%ymm9,%ymm1,%ymm3
vperm2f128 $0x20,%ymm8,%ymm0,%ymm0
vperm2f128 $0x20,%ymm9,%ymm1,%ymm1
# stores
vmovupd %ymm0,(%rdi) # ra0
vmovupd %ymm1,0x20(%rdi) # ra4
vmovupd %ymm2,0x40(%rdi) # ra8
vmovupd %ymm3,0x60(%rdi) # ra12
vmovupd %ymm4,(%rsi) # ia0
vmovupd %ymm5,0x20(%rsi) # ia4
vmovupd %ymm6,0x40(%rsi) # ia8
vmovupd %ymm7,0x60(%rsi) # ia12
vzeroupper
ret
.size fft16_avx2_fma_asm, .-fft16_avx2_fma_asm
.section .note.GNU-stack,"",@progbits


@@ -0,0 +1,271 @@
use std::arch::x86_64::{
__m128d, __m256d, _mm_load_pd, _mm256_add_pd, _mm256_fmadd_pd, _mm256_fmsub_pd, _mm256_loadu_pd, _mm256_mul_pd,
_mm256_permute2f128_pd, _mm256_set_m128d, _mm256_storeu_pd, _mm256_sub_pd, _mm256_unpackhi_pd, _mm256_unpacklo_pd,
};
use crate::reim::{as_arr, as_arr_mut};
#[target_feature(enable = "avx2,fma")]
pub(crate) fn fft_avx2_fma(m: usize, omg: &[f64], data: &mut [f64]) {
if m < 16 {
use poulpy_hal::reference::fft64::reim::fft_ref;
fft_ref(m, omg, data);
return;
}
assert!(data.len() == 2 * m);
let (re, im) = data.split_at_mut(m);
if m == 16 {
fft16_avx2_fma(
as_arr_mut::<16, f64>(re),
as_arr_mut::<16, f64>(im),
as_arr::<16, f64>(omg),
)
} else if m <= 2048 {
fft_bfs_16_avx2_fma(m, re, im, omg, 0);
} else {
fft_rec_16_avx2_fma(m, re, im, omg, 0);
}
}
unsafe extern "sysv64" {
unsafe fn fft16_avx2_fma_asm(re: *mut f64, im: *mut f64, omg: *const f64);
}
#[target_feature(enable = "avx2,fma")]
fn fft16_avx2_fma(re: &mut [f64; 16], im: &mut [f64; 16], omg: &[f64; 16]) {
unsafe {
fft16_avx2_fma_asm(re.as_mut_ptr(), im.as_mut_ptr(), omg.as_ptr());
}
}
#[target_feature(enable = "avx2,fma")]
fn fft_rec_16_avx2_fma(m: usize, re: &mut [f64], im: &mut [f64], omg: &[f64], mut pos: usize) -> usize {
if m <= 2048 {
return fft_bfs_16_avx2_fma(m, re, im, omg, pos);
};
let h: usize = m >> 1;
twiddle_fft_avx2_fma(h, re, im, *as_arr::<2, f64>(&omg[pos..]));
pos += 2;
pos = fft_rec_16_avx2_fma(h, re, im, omg, pos);
pos = fft_rec_16_avx2_fma(h, &mut re[h..], &mut im[h..], omg, pos);
pos
}
#[target_feature(enable = "avx2,fma")]
fn fft_bfs_16_avx2_fma(m: usize, re: &mut [f64], im: &mut [f64], omg: &[f64], mut pos: usize) -> usize {
let log_m: usize = (usize::BITS - (m - 1).leading_zeros()) as usize;
let mut mm: usize = m;
if !log_m.is_multiple_of(2) {
let h: usize = mm >> 1;
twiddle_fft_avx2_fma(h, re, im, *as_arr::<2, f64>(&omg[pos..]));
pos += 2;
mm = h
}
while mm > 16 {
let h: usize = mm >> 2;
for off in (0..m).step_by(mm) {
bitwiddle_fft_avx2_fma(
h,
&mut re[off..],
&mut im[off..],
as_arr::<4, f64>(&omg[pos..]),
);
pos += 4;
}
mm = h
}
for off in (0..m).step_by(16) {
fft16_avx2_fma(
as_arr_mut::<16, f64>(&mut re[off..]),
as_arr_mut::<16, f64>(&mut im[off..]),
as_arr::<16, f64>(&omg[pos..]),
);
pos += 16;
}
pos
}
#[target_feature(enable = "avx2,fma")]
fn twiddle_fft_avx2_fma(h: usize, re: &mut [f64], im: &mut [f64], omg: [f64; 2]) {
unsafe {
let omx: __m128d = _mm_load_pd(omg.as_ptr());
let omra: __m256d = _mm256_set_m128d(omx, omx);
let omi: __m256d = _mm256_unpackhi_pd(omra, omra);
let omr: __m256d = _mm256_unpacklo_pd(omra, omra);
let mut r0: *mut f64 = re.as_mut_ptr();
let mut r1: *mut f64 = re.as_mut_ptr().add(h);
let mut i0: *mut f64 = im.as_mut_ptr();
let mut i1: *mut f64 = im.as_mut_ptr().add(h);
for _ in (0..h).step_by(4) {
let mut ur0: __m256d = _mm256_loadu_pd(r0);
let mut ur1: __m256d = _mm256_loadu_pd(r1);
let mut ui0: __m256d = _mm256_loadu_pd(i0);
let mut ui1: __m256d = _mm256_loadu_pd(i1);
let mut tra: __m256d = _mm256_mul_pd(omi, ui1);
let mut tia: __m256d = _mm256_mul_pd(omi, ur1);
tra = _mm256_fmsub_pd(omr, ur1, tra);
tia = _mm256_fmadd_pd(omr, ui1, tia);
ur1 = _mm256_sub_pd(ur0, tra);
ui1 = _mm256_sub_pd(ui0, tia);
ur0 = _mm256_add_pd(ur0, tra);
ui0 = _mm256_add_pd(ui0, tia);
_mm256_storeu_pd(r0, ur0);
_mm256_storeu_pd(r1, ur1);
_mm256_storeu_pd(i0, ui0);
_mm256_storeu_pd(i1, ui1);
r0 = r0.add(4);
r1 = r1.add(4);
i0 = i0.add(4);
i1 = i1.add(4);
}
}
}
#[target_feature(enable = "avx2,fma")]
fn bitwiddle_fft_avx2_fma(h: usize, re: &mut [f64], im: &mut [f64], omg: &[f64; 4]) {
unsafe {
let mut r0: *mut f64 = re.as_mut_ptr();
let mut r1: *mut f64 = re.as_mut_ptr().add(h);
let mut r2: *mut f64 = re.as_mut_ptr().add(2 * h);
let mut r3: *mut f64 = re.as_mut_ptr().add(3 * h);
let mut i0: *mut f64 = im.as_mut_ptr();
let mut i1: *mut f64 = im.as_mut_ptr().add(h);
let mut i2: *mut f64 = im.as_mut_ptr().add(2 * h);
let mut i3: *mut f64 = im.as_mut_ptr().add(3 * h);
let om0: __m256d = _mm256_loadu_pd(omg.as_ptr());
let omb: __m256d = _mm256_permute2f128_pd(om0, om0, 0x11);
let oma: __m256d = _mm256_permute2f128_pd(om0, om0, 0x00);
let omai: __m256d = _mm256_unpackhi_pd(oma, oma);
let omar: __m256d = _mm256_unpacklo_pd(oma, oma);
let ombi: __m256d = _mm256_unpackhi_pd(omb, omb);
let ombr: __m256d = _mm256_unpacklo_pd(omb, omb);
for _ in (0..h).step_by(4) {
let mut ur0: __m256d = _mm256_loadu_pd(r0);
let mut ur1: __m256d = _mm256_loadu_pd(r1);
let mut ur2: __m256d = _mm256_loadu_pd(r2);
let mut ur3: __m256d = _mm256_loadu_pd(r3);
let mut ui0: __m256d = _mm256_loadu_pd(i0);
let mut ui1: __m256d = _mm256_loadu_pd(i1);
let mut ui2: __m256d = _mm256_loadu_pd(i2);
let mut ui3: __m256d = _mm256_loadu_pd(i3);
let mut tra: __m256d = _mm256_mul_pd(omai, ui2);
let mut trb: __m256d = _mm256_mul_pd(omai, ui3);
let mut tia: __m256d = _mm256_mul_pd(omai, ur2);
let mut tib: __m256d = _mm256_mul_pd(omai, ur3);
tra = _mm256_fmsub_pd(omar, ur2, tra);
trb = _mm256_fmsub_pd(omar, ur3, trb);
tia = _mm256_fmadd_pd(omar, ui2, tia);
tib = _mm256_fmadd_pd(omar, ui3, tib);
ur2 = _mm256_sub_pd(ur0, tra);
ur3 = _mm256_sub_pd(ur1, trb);
ui2 = _mm256_sub_pd(ui0, tia);
ui3 = _mm256_sub_pd(ui1, tib);
ur0 = _mm256_add_pd(ur0, tra);
ur1 = _mm256_add_pd(ur1, trb);
ui0 = _mm256_add_pd(ui0, tia);
ui1 = _mm256_add_pd(ui1, tib);
tra = _mm256_mul_pd(ombi, ui1);
trb = _mm256_mul_pd(ombr, ui3);
tia = _mm256_mul_pd(ombi, ur1);
tib = _mm256_mul_pd(ombr, ur3);
tra = _mm256_fmsub_pd(ombr, ur1, tra);
trb = _mm256_fmadd_pd(ombi, ur3, trb);
tia = _mm256_fmadd_pd(ombr, ui1, tia);
tib = _mm256_fmsub_pd(ombi, ui3, tib);
ur1 = _mm256_sub_pd(ur0, tra);
ur3 = _mm256_add_pd(ur2, trb);
ui1 = _mm256_sub_pd(ui0, tia);
ui3 = _mm256_add_pd(ui2, tib);
ur0 = _mm256_add_pd(ur0, tra);
ur2 = _mm256_sub_pd(ur2, trb);
ui0 = _mm256_add_pd(ui0, tia);
ui2 = _mm256_sub_pd(ui2, tib);
_mm256_storeu_pd(r0, ur0);
_mm256_storeu_pd(r1, ur1);
_mm256_storeu_pd(r2, ur2);
_mm256_storeu_pd(r3, ur3);
_mm256_storeu_pd(i0, ui0);
_mm256_storeu_pd(i1, ui1);
_mm256_storeu_pd(i2, ui2);
_mm256_storeu_pd(i3, ui3);
r0 = r0.add(4);
r1 = r1.add(4);
r2 = r2.add(4);
r3 = r3.add(4);
i0 = i0.add(4);
i1 = i1.add(4);
i2 = i2.add(4);
i3 = i3.add(4);
}
}
}
#[test]
fn test_fft_avx2_fma() {
use super::*;
#[target_feature(enable = "avx2,fma")]
fn internal(log_m: usize) {
use poulpy_hal::reference::fft64::reim::ReimFFTRef;
let m = 1 << log_m;
let table: ReimFFTTable<f64> = ReimFFTTable::<f64>::new(m);
let mut values_0: Vec<f64> = vec![0f64; m << 1];
let scale: f64 = 1.0f64 / m as f64;
values_0
.iter_mut()
.enumerate()
.for_each(|(i, x)| *x = (i + 1) as f64 * scale);
let mut values_1: Vec<f64> = vec![0f64; m << 1];
values_1
.iter_mut()
.zip(values_0.iter())
.for_each(|(y, x)| *y = *x);
ReimFFTAvx::reim_dft_execute(&table, &mut values_0);
ReimFFTRef::reim_dft_execute(&table, &mut values_1);
let max_diff: f64 = 1.0 / ((1u64 << (53 - log_m - 1)) as f64);
for i in 0..m * 2 {
let diff: f64 = (values_0[i] - values_1[i]).abs();
assert!(
diff <= max_diff,
"{} -> {}-{} = {}",
i,
values_0[i],
values_1[i],
diff
)
}
}
if std::is_x86_feature_detected!("avx2") && std::is_x86_feature_detected!("fma") {
for log_m in 0..16 {
unsafe { internal(log_m) }
}
} else {
eprintln!("skipping: CPU lacks avx2/fma");
}
}


@@ -0,0 +1,340 @@
/// # Safety
/// Caller must ensure the CPU supports AVX2 and FMA (e.g., via `is_x86_feature_detected!("avx2")` and `is_x86_feature_detected!("fma")`).
#[target_feature(enable = "avx2,fma")]
pub fn reim_add_avx2_fma(res: &mut [f64], a: &[f64], b: &[f64]) {
#[cfg(debug_assertions)]
{
assert_eq!(a.len(), res.len());
assert_eq!(b.len(), res.len());
}
use std::arch::x86_64::{__m256d, _mm256_add_pd, _mm256_loadu_pd, _mm256_storeu_pd};
let span: usize = res.len() >> 2;
unsafe {
let mut rr: *mut f64 = res.as_mut_ptr();
let mut aa: *const f64 = a.as_ptr();
let mut bb: *const f64 = b.as_ptr();
for _ in 0..span {
let a_256: __m256d = _mm256_loadu_pd(aa);
let b_256: __m256d = _mm256_loadu_pd(bb);
_mm256_storeu_pd(rr, _mm256_add_pd(a_256, b_256));
rr = rr.add(4);
aa = aa.add(4);
bb = bb.add(4);
}
}
}
/// # Safety
/// Caller must ensure the CPU supports AVX2 and FMA (e.g., via `is_x86_feature_detected!("avx2")` and `is_x86_feature_detected!("fma")`).
#[target_feature(enable = "avx2,fma")]
pub fn reim_add_inplace_avx2_fma(res: &mut [f64], a: &[f64]) {
#[cfg(debug_assertions)]
{
assert_eq!(a.len(), res.len());
}
use std::arch::x86_64::{__m256d, _mm256_add_pd, _mm256_loadu_pd, _mm256_storeu_pd};
let span: usize = res.len() >> 2;
unsafe {
let mut rr: *mut f64 = res.as_mut_ptr();
let mut aa: *const f64 = a.as_ptr();
for _ in 0..span {
let a_256: __m256d = _mm256_loadu_pd(aa);
let r_256: __m256d = _mm256_loadu_pd(rr);
_mm256_storeu_pd(rr, _mm256_add_pd(r_256, a_256));
rr = rr.add(4);
aa = aa.add(4);
}
}
}
/// # Safety
/// Caller must ensure the CPU supports AVX2 and FMA (e.g., via `is_x86_feature_detected!("avx2")` and `is_x86_feature_detected!("fma")`).
#[target_feature(enable = "avx2,fma")]
pub fn reim_sub_avx2_fma(res: &mut [f64], a: &[f64], b: &[f64]) {
#[cfg(debug_assertions)]
{
assert_eq!(a.len(), res.len());
assert_eq!(b.len(), res.len());
}
use std::arch::x86_64::{__m256d, _mm256_loadu_pd, _mm256_storeu_pd, _mm256_sub_pd};
let span: usize = res.len() >> 2;
unsafe {
let mut rr: *mut f64 = res.as_mut_ptr();
let mut aa: *const f64 = a.as_ptr();
let mut bb: *const f64 = b.as_ptr();
for _ in 0..span {
let a_256: __m256d = _mm256_loadu_pd(aa);
let b_256: __m256d = _mm256_loadu_pd(bb);
_mm256_storeu_pd(rr, _mm256_sub_pd(a_256, b_256));
rr = rr.add(4);
aa = aa.add(4);
bb = bb.add(4);
}
}
}
/// # Safety
/// Caller must ensure the CPU supports AVX2 and FMA (e.g., via `is_x86_feature_detected!("avx2")` and `is_x86_feature_detected!("fma")`).
#[target_feature(enable = "avx2,fma")]
pub fn reim_sub_inplace_avx2_fma(res: &mut [f64], a: &[f64]) {
#[cfg(debug_assertions)]
{
assert_eq!(a.len(), res.len());
}
use std::arch::x86_64::{__m256d, _mm256_loadu_pd, _mm256_storeu_pd, _mm256_sub_pd};
let span: usize = res.len() >> 2;
unsafe {
let mut rr: *mut f64 = res.as_mut_ptr();
let mut aa: *const f64 = a.as_ptr();
for _ in 0..span {
let a_256: __m256d = _mm256_loadu_pd(aa);
let r_256: __m256d = _mm256_loadu_pd(rr);
_mm256_storeu_pd(rr, _mm256_sub_pd(r_256, a_256));
rr = rr.add(4);
aa = aa.add(4);
}
}
}
/// # Safety
/// Caller must ensure the CPU supports AVX2 and FMA (e.g., via `is_x86_feature_detected!("avx2")` and `is_x86_feature_detected!("fma")`).
#[target_feature(enable = "avx2,fma")]
pub fn reim_sub_negate_inplace_avx2_fma(res: &mut [f64], a: &[f64]) {
#[cfg(debug_assertions)]
{
assert_eq!(a.len(), res.len());
}
use std::arch::x86_64::{__m256d, _mm256_loadu_pd, _mm256_storeu_pd, _mm256_sub_pd};
let span: usize = res.len() >> 2;
unsafe {
let mut rr: *mut f64 = res.as_mut_ptr();
let mut aa: *const f64 = a.as_ptr();
for _ in 0..span {
let a_256: __m256d = _mm256_loadu_pd(aa);
let r_256: __m256d = _mm256_loadu_pd(rr);
_mm256_storeu_pd(rr, _mm256_sub_pd(a_256, r_256));
rr = rr.add(4);
aa = aa.add(4);
}
}
}
/// # Safety
/// Caller must ensure the CPU supports AVX2 and FMA (e.g., via `is_x86_feature_detected!("avx2")` and `is_x86_feature_detected!("fma")`).
#[target_feature(enable = "avx2,fma")]
pub fn reim_negate_avx2_fma(res: &mut [f64], a: &[f64]) {
#[cfg(debug_assertions)]
{
assert_eq!(a.len(), res.len());
}
use std::arch::x86_64::{__m256d, _mm256_loadu_pd, _mm256_storeu_pd, _mm256_xor_pd};
let span: usize = res.len() >> 2;
unsafe {
use std::arch::x86_64::_mm256_set1_pd;
let mut rr: *mut f64 = res.as_mut_ptr();
let mut aa: *const f64 = a.as_ptr();
let neg0: __m256d = _mm256_set1_pd(-0.0);
for _ in 0..span {
let a_256: __m256d = _mm256_loadu_pd(aa);
_mm256_storeu_pd(rr, _mm256_xor_pd(a_256, neg0));
rr = rr.add(4);
aa = aa.add(4);
}
}
}
/// # Safety
/// Caller must ensure the CPU supports AVX2 and FMA (e.g., via `is_x86_feature_detected!("avx2")` and `is_x86_feature_detected!("fma")`).
#[target_feature(enable = "avx2,fma")]
pub fn reim_negate_inplace_avx2_fma(res: &mut [f64]) {
use std::arch::x86_64::{__m256d, _mm256_loadu_pd, _mm256_storeu_pd, _mm256_xor_pd};
let span: usize = res.len() >> 2;
unsafe {
use std::arch::x86_64::_mm256_set1_pd;
let mut rr: *mut f64 = res.as_mut_ptr();
let neg0: __m256d = _mm256_set1_pd(-0.0);
for _ in 0..span {
let r_256: __m256d = _mm256_loadu_pd(rr);
_mm256_storeu_pd(rr, _mm256_xor_pd(r_256, neg0));
rr = rr.add(4);
}
}
}
/// # Safety
/// Caller must ensure the CPU supports AVX2 and FMA (e.g., via `is_x86_feature_detected!("avx2")` and `is_x86_feature_detected!("fma")`).
#[target_feature(enable = "avx2,fma")]
pub fn reim_addmul_avx2_fma(res: &mut [f64], a: &[f64], b: &[f64]) {
#[cfg(debug_assertions)]
{
assert_eq!(a.len(), res.len());
assert_eq!(b.len(), res.len());
}
let m: usize = res.len() >> 1;
let (rr, ri) = res.split_at_mut(m);
let (ar, ai) = a.split_at(m);
let (br, bi) = b.split_at(m);
unsafe {
let mut rr_ptr: *mut f64 = rr.as_mut_ptr();
let mut ri_ptr: *mut f64 = ri.as_mut_ptr();
let mut ar_ptr: *const f64 = ar.as_ptr();
let mut ai_ptr: *const f64 = ai.as_ptr();
let mut br_ptr: *const f64 = br.as_ptr();
let mut bi_ptr: *const f64 = bi.as_ptr();
use std::arch::x86_64::{__m256d, _mm256_fmadd_pd, _mm256_fmsub_pd, _mm256_loadu_pd, _mm256_storeu_pd};
for _ in 0..(m >> 2) {
let mut rr: __m256d = _mm256_loadu_pd(rr_ptr);
let mut ri: __m256d = _mm256_loadu_pd(ri_ptr);
let ar: __m256d = _mm256_loadu_pd(ar_ptr);
let ai: __m256d = _mm256_loadu_pd(ai_ptr);
let br: __m256d = _mm256_loadu_pd(br_ptr);
let bi: __m256d = _mm256_loadu_pd(bi_ptr);
rr = _mm256_fmsub_pd(ai, bi, rr);
rr = _mm256_fmsub_pd(ar, br, rr);
ri = _mm256_fmadd_pd(ar, bi, ri);
ri = _mm256_fmadd_pd(ai, br, ri);
_mm256_storeu_pd(rr_ptr, rr);
_mm256_storeu_pd(ri_ptr, ri);
rr_ptr = rr_ptr.add(4);
ri_ptr = ri_ptr.add(4);
ar_ptr = ar_ptr.add(4);
ai_ptr = ai_ptr.add(4);
br_ptr = br_ptr.add(4);
bi_ptr = bi_ptr.add(4);
}
}
}
/// # Safety
/// Caller must ensure the CPU supports AVX2 and FMA (e.g., via `is_x86_feature_detected!("avx2")` and `is_x86_feature_detected!("fma")`).
#[target_feature(enable = "avx2,fma")]
pub fn reim_mul_avx2_fma(res: &mut [f64], a: &[f64], b: &[f64]) {
#[cfg(debug_assertions)]
{
assert_eq!(a.len(), res.len());
assert_eq!(b.len(), res.len());
}
let m: usize = res.len() >> 1;
let (rr, ri) = res.split_at_mut(m);
let (ar, ai) = a.split_at(m);
let (br, bi) = b.split_at(m);
unsafe {
let mut rr_ptr: *mut f64 = rr.as_mut_ptr();
let mut ri_ptr: *mut f64 = ri.as_mut_ptr();
let mut ar_ptr: *const f64 = ar.as_ptr();
let mut ai_ptr: *const f64 = ai.as_ptr();
let mut br_ptr: *const f64 = br.as_ptr();
let mut bi_ptr: *const f64 = bi.as_ptr();
use std::arch::x86_64::{__m256d, _mm256_fmadd_pd, _mm256_fmsub_pd, _mm256_loadu_pd, _mm256_mul_pd, _mm256_storeu_pd};
for _ in 0..(m >> 2) {
let ar: __m256d = _mm256_loadu_pd(ar_ptr);
let ai: __m256d = _mm256_loadu_pd(ai_ptr);
let br: __m256d = _mm256_loadu_pd(br_ptr);
let bi: __m256d = _mm256_loadu_pd(bi_ptr);
let t1: __m256d = _mm256_mul_pd(ai, bi);
let t2: __m256d = _mm256_mul_pd(ar, bi);
let rr: __m256d = _mm256_fmsub_pd(ar, br, t1);
let ri: __m256d = _mm256_fmadd_pd(ai, br, t2);
_mm256_storeu_pd(rr_ptr, rr);
_mm256_storeu_pd(ri_ptr, ri);
rr_ptr = rr_ptr.add(4);
ri_ptr = ri_ptr.add(4);
ar_ptr = ar_ptr.add(4);
ai_ptr = ai_ptr.add(4);
br_ptr = br_ptr.add(4);
bi_ptr = bi_ptr.add(4);
}
}
}
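
These `reim_*` kernels all work on the split layout `[re[0..m] | im[0..m]]`. The FMA sequence in `reim_mul_avx2_fma` (and the accumulating variant `reim_addmul_avx2_fma` above) is the textbook complex product, four lanes at a time; for reference:

```latex
(a_r + i\,a_i)\,(b_r + i\,b_i) \;=\; (a_r b_r - a_i b_i) \;+\; i\,(a_r b_i + a_i b_r),
```

so `rr = fmsub(ar, br, ai*bi)` produces the real half and `ri = fmadd(ai, br, ar*bi)` the imaginary half.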
/// # Safety
/// Caller must ensure the CPU supports AVX2 and FMA (e.g., via `is_x86_feature_detected!("avx2")` and `is_x86_feature_detected!("fma")`).
#[target_feature(enable = "avx2,fma")]
pub fn reim_mul_inplace_avx2_fma(res: &mut [f64], a: &[f64]) {
#[cfg(debug_assertions)]
{
assert_eq!(a.len(), res.len());
}
let m: usize = res.len() >> 1;
let (rr, ri) = res.split_at_mut(m);
let (ar, ai) = a.split_at(m);
unsafe {
let mut rr_ptr: *mut f64 = rr.as_mut_ptr();
let mut ri_ptr: *mut f64 = ri.as_mut_ptr();
let mut ar_ptr: *const f64 = ar.as_ptr();
let mut ai_ptr: *const f64 = ai.as_ptr();
use std::arch::x86_64::{__m256d, _mm256_fmadd_pd, _mm256_fmsub_pd, _mm256_loadu_pd, _mm256_mul_pd, _mm256_storeu_pd};
for _ in 0..(m >> 2) {
let ar: __m256d = _mm256_loadu_pd(ar_ptr);
let ai: __m256d = _mm256_loadu_pd(ai_ptr);
let br: __m256d = _mm256_loadu_pd(rr_ptr);
let bi: __m256d = _mm256_loadu_pd(ri_ptr);
let t1: __m256d = _mm256_mul_pd(ai, bi);
let t2: __m256d = _mm256_mul_pd(ar, bi);
let rr = _mm256_fmsub_pd(ar, br, t1);
let ri = _mm256_fmadd_pd(ai, br, t2);
_mm256_storeu_pd(rr_ptr, rr);
_mm256_storeu_pd(ri_ptr, ri);
rr_ptr = rr_ptr.add(4);
ri_ptr = ri_ptr.add(4);
ar_ptr = ar_ptr.add(4);
ai_ptr = ai_ptr.add(4);
}
}
}


@@ -0,0 +1,181 @@
# ----------------------------------------------------------------------
# This kernel is a direct port of the IFFT16 routine from spqlios-arithmetic
# (https://github.com/tfhe/spqlios-arithmetic)
# ----------------------------------------------------------------------
#
.text
.globl ifft16_avx2_fma_asm
.hidden ifft16_avx2_fma_asm
.p2align 4, 0x90
.type ifft16_avx2_fma_asm,@function
ifft16_avx2_fma_asm:
.att_syntax prefix
vmovupd (%rdi),%ymm0 # ra0
vmovupd 0x20(%rdi),%ymm1 # ra4
vmovupd 0x40(%rdi),%ymm2 # ra8
vmovupd 0x60(%rdi),%ymm3 # ra12
vmovupd (%rsi),%ymm4 # ia0
vmovupd 0x20(%rsi),%ymm5 # ia4
vmovupd 0x40(%rsi),%ymm6 # ia8
vmovupd 0x60(%rsi),%ymm7 # ia12
1:
vmovupd 0x00(%rdx),%ymm12
vmovupd 0x20(%rdx),%ymm13
vperm2f128 $0x31,%ymm2,%ymm0,%ymm8 # ymm8 contains re to mul (tw)
vperm2f128 $0x31,%ymm3,%ymm1,%ymm9 # ymm9 contains re to mul (itw)
vperm2f128 $0x31,%ymm6,%ymm4,%ymm10 # ymm10 contains im to mul (tw)
vperm2f128 $0x31,%ymm7,%ymm5,%ymm11 # ymm11 contains im to mul (itw)
vperm2f128 $0x20,%ymm2,%ymm0,%ymm0 # ymm0 contains re to add (tw)
vperm2f128 $0x20,%ymm3,%ymm1,%ymm1 # ymm1 contains re to add (itw)
vperm2f128 $0x20,%ymm6,%ymm4,%ymm2 # ymm2 contains im to add (tw)
vperm2f128 $0x20,%ymm7,%ymm5,%ymm3 # ymm3 contains im to add (itw)
vunpckhpd %ymm1,%ymm0,%ymm4 # (0,1) -> (0,4)
vunpckhpd %ymm3,%ymm2,%ymm6 # (2,3) -> (2,6)
vunpckhpd %ymm9,%ymm8,%ymm5 # (8,9) -> (1,5)
vunpckhpd %ymm11,%ymm10,%ymm7 # (10,11) -> (3,7)
vunpcklpd %ymm1,%ymm0,%ymm0
vunpcklpd %ymm3,%ymm2,%ymm2
vunpcklpd %ymm9,%ymm8,%ymm1
vunpcklpd %ymm11,%ymm10,%ymm3
# invctwiddle Re:(ymm0,ymm4) and Im:(ymm2,ymm6) with omega=(ymm12,ymm13)
# invcitwiddle Re:(ymm1,ymm5) and Im:(ymm3,ymm7) with omega=(ymm12,ymm13)
vsubpd %ymm4,%ymm0,%ymm8 # retw
vsubpd %ymm5,%ymm1,%ymm9 # reitw
vsubpd %ymm6,%ymm2,%ymm10 # imtw
vsubpd %ymm7,%ymm3,%ymm11 # imitw
vaddpd %ymm4,%ymm0,%ymm0
vaddpd %ymm5,%ymm1,%ymm1
vaddpd %ymm6,%ymm2,%ymm2
vaddpd %ymm7,%ymm3,%ymm3
# multiply 8,9,10,11 by 12,13, result to: 4,5,6,7
# twiddles use reom=ymm12, imom=ymm13
# invtwiddles use reom=ymm13, imom=-ymm12
vmulpd %ymm10,%ymm13,%ymm4 # imtw.omai (tw)
vmulpd %ymm11,%ymm12,%ymm5 # imitw.omar (itw)
vmulpd %ymm8,%ymm13,%ymm6 # retw.omai (tw)
vmulpd %ymm9,%ymm12,%ymm7 # reitw.omar (itw)
vfmsub231pd %ymm8,%ymm12,%ymm4 # rprod0 (tw)
vfmadd231pd %ymm9,%ymm13,%ymm5 # rprod4 (itw)
vfmadd231pd %ymm10,%ymm12,%ymm6 # iprod0 (tw)
vfmsub231pd %ymm11,%ymm13,%ymm7 # iprod4 (itw)
vunpckhpd %ymm7,%ymm3,%ymm11 # (0,4) -> (0,1)
vunpckhpd %ymm5,%ymm1,%ymm9 # (2,6) -> (2,3)
vunpcklpd %ymm7,%ymm3,%ymm10
vunpcklpd %ymm5,%ymm1,%ymm8
vunpckhpd %ymm6,%ymm2,%ymm3 # (1,5) -> (8,9)
vunpckhpd %ymm4,%ymm0,%ymm1 # (3,7) -> (10,11)
vunpcklpd %ymm6,%ymm2,%ymm2
vunpcklpd %ymm4,%ymm0,%ymm0
2:
vmovupd 0x40(%rdx),%ymm12
vshufpd $15, %ymm12, %ymm12, %ymm13 # ymm13: omaiii'i'
vshufpd $0, %ymm12, %ymm12, %ymm12 # ymm12: omarrr'r'
# invctwiddle Re:(ymm0,ymm8) and Im:(ymm2,ymm10) with omega=(ymm12,ymm13)
# invcitwiddle Re:(ymm1,ymm9) and Im:(ymm3,ymm11) with omega=(ymm12,ymm13)
vsubpd %ymm8,%ymm0,%ymm4 # retw
vsubpd %ymm9,%ymm1,%ymm5 # reitw
vsubpd %ymm10,%ymm2,%ymm6 # imtw
vsubpd %ymm11,%ymm3,%ymm7 # imitw
vaddpd %ymm8,%ymm0,%ymm0
vaddpd %ymm9,%ymm1,%ymm1
vaddpd %ymm10,%ymm2,%ymm2
vaddpd %ymm11,%ymm3,%ymm3
# multiply 4,5,6,7 by 12,13, result to 8,9,10,11
# twiddles use reom=ymm12, imom=ymm13
# invtwiddles use reom=ymm13, imom=-ymm12
vmulpd %ymm6,%ymm13,%ymm8 # imtw.omai (tw)
vmulpd %ymm7,%ymm12,%ymm9 # imitw.omar (itw)
vmulpd %ymm4,%ymm13,%ymm10 # retw.omai (tw)
vmulpd %ymm5,%ymm12,%ymm11 # reitw.omar (itw)
vfmsub231pd %ymm4,%ymm12,%ymm8 # rprod0 (tw)
vfmadd231pd %ymm5,%ymm13,%ymm9 # rprod4 (itw)
vfmadd231pd %ymm6,%ymm12,%ymm10 # iprod0 (tw)
vfmsub231pd %ymm7,%ymm13,%ymm11 # iprod4 (itw)
vperm2f128 $0x31,%ymm10,%ymm2,%ymm6
vperm2f128 $0x31,%ymm11,%ymm3,%ymm7
vperm2f128 $0x20,%ymm10,%ymm2,%ymm4
vperm2f128 $0x20,%ymm11,%ymm3,%ymm5
vperm2f128 $0x31,%ymm8,%ymm0,%ymm2
vperm2f128 $0x31,%ymm9,%ymm1,%ymm3
vperm2f128 $0x20,%ymm8,%ymm0,%ymm0
vperm2f128 $0x20,%ymm9,%ymm1,%ymm1
3:
vmovupd 0x60(%rdx),%xmm12
vinsertf128 $1, %xmm12, %ymm12, %ymm12 # omriri
vshufpd $15, %ymm12, %ymm12, %ymm13 # ymm13: omai
vshufpd $0, %ymm12, %ymm12, %ymm12 # ymm12: omar
# invctwiddle Re:(ymm0,ymm1) and Im:(ymm4,ymm5) with omega=(ymm12,ymm13)
# invcitwiddle Re:(ymm2,ymm3) and Im:(ymm6,ymm7) with omega=(ymm12,ymm13)
vsubpd %ymm1,%ymm0,%ymm8 # retw
vsubpd %ymm3,%ymm2,%ymm9 # reitw
vsubpd %ymm5,%ymm4,%ymm10 # imtw
vsubpd %ymm7,%ymm6,%ymm11 # imitw
vaddpd %ymm1,%ymm0,%ymm0
vaddpd %ymm3,%ymm2,%ymm2
vaddpd %ymm5,%ymm4,%ymm4
vaddpd %ymm7,%ymm6,%ymm6
# multiply 8,9,10,11 by 12,13, result to 1,3,5,7
# twiddles use reom=ymm12, imom=ymm13
# invtwiddles use reom=ymm13, imom=-ymm12
vmulpd %ymm10,%ymm13,%ymm1 # imtw.omai (tw)
vmulpd %ymm11,%ymm12,%ymm3 # imitw.omar (itw)
vmulpd %ymm8,%ymm13,%ymm5 # retw.omai (tw)
vmulpd %ymm9,%ymm12,%ymm7 # reitw.omar (itw)
vfmsub231pd %ymm8,%ymm12,%ymm1 # rprod0 (tw)
vfmadd231pd %ymm9,%ymm13,%ymm3 # rprod4 (itw)
vfmadd231pd %ymm10,%ymm12,%ymm5 # iprod0 (tw)
vfmsub231pd %ymm11,%ymm13,%ymm7 # iprod4 (itw)
4:
vmovupd 0x70(%rdx),%xmm12
vinsertf128 $1, %xmm12, %ymm12, %ymm12 # omriri
vshufpd $15, %ymm12, %ymm12, %ymm13 # ymm13: omai
vshufpd $0, %ymm12, %ymm12, %ymm12 # ymm12: omar
# invctwiddle Re:(ymm0,ymm2) and Im:(ymm4,ymm6) with omega=(ymm12,ymm13)
# invctwiddle Re:(ymm1,ymm3) and Im:(ymm5,ymm7) with omega=(ymm12,ymm13)
vsubpd %ymm2,%ymm0,%ymm8 # retw1
vsubpd %ymm3,%ymm1,%ymm9 # retw2
vsubpd %ymm6,%ymm4,%ymm10 # imtw1
vsubpd %ymm7,%ymm5,%ymm11 # imtw2
vaddpd %ymm2,%ymm0,%ymm0
vaddpd %ymm3,%ymm1,%ymm1
vaddpd %ymm6,%ymm4,%ymm4
vaddpd %ymm7,%ymm5,%ymm5
# multiply 8,9,10,11 by 12,13, result to 2,3,6,7
# twiddles use reom=ymm12, imom=ymm13
vmulpd %ymm10,%ymm13,%ymm2 # imtw1.omai
vmulpd %ymm11,%ymm13,%ymm3 # imtw2.omai
vmulpd %ymm8,%ymm13,%ymm6 # retw1.omai
vmulpd %ymm9,%ymm13,%ymm7 # retw2.omai
vfmsub231pd %ymm8,%ymm12,%ymm2 # rprod0
vfmsub231pd %ymm9,%ymm12,%ymm3 # rprod4
vfmadd231pd %ymm10,%ymm12,%ymm6 # iprod0
vfmadd231pd %ymm11,%ymm12,%ymm7 # iprod4
5:
vmovupd %ymm0,(%rdi) # ra0
vmovupd %ymm1,0x20(%rdi) # ra4
vmovupd %ymm2,0x40(%rdi) # ra8
vmovupd %ymm3,0x60(%rdi) # ra12
vmovupd %ymm4,(%rsi) # ia0
vmovupd %ymm5,0x20(%rsi) # ia4
vmovupd %ymm6,0x40(%rsi) # ia8
vmovupd %ymm7,0x60(%rsi) # ia12
vzeroupper
ret
.size ifft16_avx2_fma_asm, .-ifft16_avx2_fma_asm
.section .note.GNU-stack,"",@progbits


@@ -0,0 +1,264 @@
use std::arch::x86_64::{
__m128d, __m256d, _mm_load_pd, _mm256_add_pd, _mm256_fmadd_pd, _mm256_fmsub_pd, _mm256_loadu_pd, _mm256_mul_pd,
_mm256_permute2f128_pd, _mm256_set_m128d, _mm256_storeu_pd, _mm256_sub_pd, _mm256_unpackhi_pd, _mm256_unpacklo_pd,
};
use crate::reim::{as_arr, as_arr_mut};
#[target_feature(enable = "avx2,fma")]
pub(crate) fn ifft_avx2_fma(m: usize, omg: &[f64], data: &mut [f64]) {
if m < 16 {
use poulpy_hal::reference::fft64::reim::ifft_ref;
ifft_ref(m, omg, data);
return;
}
assert!(data.len() == 2 * m);
let (re, im) = data.split_at_mut(m);
if m == 16 {
ifft16_avx2_fma(
as_arr_mut::<16, f64>(re),
as_arr_mut::<16, f64>(im),
as_arr::<16, f64>(omg),
)
} else if m <= 2048 {
ifft_bfs_16_avx2_fma(m, re, im, omg, 0);
} else {
ifft_rec_16_avx2_fma(m, re, im, omg, 0);
}
}
unsafe extern "sysv64" {
unsafe fn ifft16_avx2_fma_asm(re: *mut f64, im: *mut f64, omg: *const f64);
}
#[target_feature(enable = "avx2,fma")]
fn ifft16_avx2_fma(re: &mut [f64; 16], im: &mut [f64; 16], omg: &[f64; 16]) {
unsafe {
ifft16_avx2_fma_asm(re.as_mut_ptr(), im.as_mut_ptr(), omg.as_ptr());
}
}
#[target_feature(enable = "avx2,fma")]
fn ifft_rec_16_avx2_fma(m: usize, re: &mut [f64], im: &mut [f64], omg: &[f64], mut pos: usize) -> usize {
if m <= 2048 {
return ifft_bfs_16_avx2_fma(m, re, im, omg, pos);
};
let h: usize = m >> 1;
pos = ifft_rec_16_avx2_fma(h, re, im, omg, pos);
pos = ifft_rec_16_avx2_fma(h, &mut re[h..], &mut im[h..], omg, pos);
inv_twiddle_ifft_avx2_fma(h, re, im, *as_arr::<2, f64>(&omg[pos..]));
pos += 2;
pos
}
#[target_feature(enable = "avx2,fma")]
fn ifft_bfs_16_avx2_fma(m: usize, re: &mut [f64], im: &mut [f64], omg: &[f64], mut pos: usize) -> usize {
let log_m: usize = (usize::BITS - (m - 1).leading_zeros()) as usize;
for off in (0..m).step_by(16) {
ifft16_avx2_fma(
as_arr_mut::<16, f64>(&mut re[off..]),
as_arr_mut::<16, f64>(&mut im[off..]),
as_arr::<16, f64>(&omg[pos..]),
);
pos += 16;
}
let mut h: usize = 16;
let m_half: usize = m >> 1;
while h < m_half {
let mm: usize = h << 2;
for off in (0..m).step_by(mm) {
inv_bitwiddle_ifft_avx2_fma(
h,
&mut re[off..],
&mut im[off..],
as_arr::<4, f64>(&omg[pos..]),
);
pos += 4;
}
h = mm;
}
if !log_m.is_multiple_of(2) {
inv_twiddle_ifft_avx2_fma(h, re, im, *as_arr::<2, f64>(&omg[pos..]));
pos += 2;
}
pos
}
#[target_feature(enable = "avx2,fma")]
fn inv_twiddle_ifft_avx2_fma(h: usize, re: &mut [f64], im: &mut [f64], omg: [f64; 2]) {
unsafe {
let omx: __m128d = _mm_load_pd(omg.as_ptr());
let omra: __m256d = _mm256_set_m128d(omx, omx);
let omi: __m256d = _mm256_unpackhi_pd(omra, omra);
let omr: __m256d = _mm256_unpacklo_pd(omra, omra);
let mut r0: *mut f64 = re.as_mut_ptr();
let mut r1: *mut f64 = re.as_mut_ptr().add(h);
let mut i0: *mut f64 = im.as_mut_ptr();
let mut i1: *mut f64 = im.as_mut_ptr().add(h);
for _ in (0..h).step_by(4) {
let mut ur0: __m256d = _mm256_loadu_pd(r0);
let mut ur1: __m256d = _mm256_loadu_pd(r1);
let mut ui0: __m256d = _mm256_loadu_pd(i0);
let mut ui1: __m256d = _mm256_loadu_pd(i1);
let tra = _mm256_sub_pd(ur0, ur1);
let tia = _mm256_sub_pd(ui0, ui1);
ur0 = _mm256_add_pd(ur0, ur1);
ui0 = _mm256_add_pd(ui0, ui1);
ur1 = _mm256_mul_pd(omi, tia);
ui1 = _mm256_mul_pd(omi, tra);
ur1 = _mm256_fmsub_pd(omr, tra, ur1);
ui1 = _mm256_fmadd_pd(omr, tia, ui1);
_mm256_storeu_pd(r0, ur0);
_mm256_storeu_pd(r1, ur1);
_mm256_storeu_pd(i0, ui0);
_mm256_storeu_pd(i1, ui1);
r0 = r0.add(4);
r1 = r1.add(4);
i0 = i0.add(4);
i1 = i1.add(4);
}
}
}
#[target_feature(enable = "avx2,fma")]
fn inv_bitwiddle_ifft_avx2_fma(h: usize, re: &mut [f64], im: &mut [f64], omg: &[f64; 4]) {
unsafe {
let mut r0: *mut f64 = re.as_mut_ptr();
let mut r1: *mut f64 = re.as_mut_ptr().add(h);
let mut r2: *mut f64 = re.as_mut_ptr().add(2 * h);
let mut r3: *mut f64 = re.as_mut_ptr().add(3 * h);
let mut i0: *mut f64 = im.as_mut_ptr();
let mut i1: *mut f64 = im.as_mut_ptr().add(h);
let mut i2: *mut f64 = im.as_mut_ptr().add(2 * h);
let mut i3: *mut f64 = im.as_mut_ptr().add(3 * h);
let om0: __m256d = _mm256_loadu_pd(omg.as_ptr());
let omb: __m256d = _mm256_permute2f128_pd(om0, om0, 0x11);
let oma: __m256d = _mm256_permute2f128_pd(om0, om0, 0x00);
let omai: __m256d = _mm256_unpackhi_pd(oma, oma);
let omar: __m256d = _mm256_unpacklo_pd(oma, oma);
let ombi: __m256d = _mm256_unpackhi_pd(omb, omb);
let ombr: __m256d = _mm256_unpacklo_pd(omb, omb);
for _ in (0..h).step_by(4) {
let mut ur0: __m256d = _mm256_loadu_pd(r0);
let mut ur1: __m256d = _mm256_loadu_pd(r1);
let mut ur2: __m256d = _mm256_loadu_pd(r2);
let mut ur3: __m256d = _mm256_loadu_pd(r3);
let mut ui0: __m256d = _mm256_loadu_pd(i0);
let mut ui1: __m256d = _mm256_loadu_pd(i1);
let mut ui2: __m256d = _mm256_loadu_pd(i2);
let mut ui3: __m256d = _mm256_loadu_pd(i3);
let mut tra: __m256d = _mm256_sub_pd(ur0, ur1);
let mut trb: __m256d = _mm256_sub_pd(ur2, ur3);
let mut tia: __m256d = _mm256_sub_pd(ui0, ui1);
let mut tib: __m256d = _mm256_sub_pd(ui2, ui3);
ur0 = _mm256_add_pd(ur0, ur1);
ur2 = _mm256_add_pd(ur2, ur3);
ui0 = _mm256_add_pd(ui0, ui1);
ui2 = _mm256_add_pd(ui2, ui3);
ur1 = _mm256_mul_pd(omai, tia);
ur3 = _mm256_mul_pd(omar, tib);
ui1 = _mm256_mul_pd(omai, tra);
ui3 = _mm256_mul_pd(omar, trb);
ur1 = _mm256_fmsub_pd(omar, tra, ur1);
ur3 = _mm256_fmadd_pd(omai, trb, ur3);
ui1 = _mm256_fmadd_pd(omar, tia, ui1);
ui3 = _mm256_fmsub_pd(omai, tib, ui3);
tra = _mm256_sub_pd(ur0, ur2);
trb = _mm256_sub_pd(ur1, ur3);
tia = _mm256_sub_pd(ui0, ui2);
tib = _mm256_sub_pd(ui1, ui3);
ur0 = _mm256_add_pd(ur0, ur2);
ur1 = _mm256_add_pd(ur1, ur3);
ui0 = _mm256_add_pd(ui0, ui2);
ui1 = _mm256_add_pd(ui1, ui3);
ur2 = _mm256_mul_pd(ombi, tia);
ur3 = _mm256_mul_pd(ombi, tib);
ui2 = _mm256_mul_pd(ombi, tra);
ui3 = _mm256_mul_pd(ombi, trb);
ur2 = _mm256_fmsub_pd(ombr, tra, ur2);
ur3 = _mm256_fmsub_pd(ombr, trb, ur3);
ui2 = _mm256_fmadd_pd(ombr, tia, ui2);
ui3 = _mm256_fmadd_pd(ombr, tib, ui3);
_mm256_storeu_pd(r0, ur0);
_mm256_storeu_pd(r1, ur1);
_mm256_storeu_pd(r2, ur2);
_mm256_storeu_pd(r3, ur3);
_mm256_storeu_pd(i0, ui0);
_mm256_storeu_pd(i1, ui1);
_mm256_storeu_pd(i2, ui2);
_mm256_storeu_pd(i3, ui3);
r0 = r0.add(4);
r1 = r1.add(4);
r2 = r2.add(4);
r3 = r3.add(4);
i0 = i0.add(4);
i1 = i1.add(4);
i2 = i2.add(4);
i3 = i3.add(4);
}
}
}
#[test]
fn test_ifft_avx2_fma() {
use super::*;
#[target_feature(enable = "avx2,fma")]
fn internal(log_m: usize) {
use poulpy_hal::reference::fft64::reim::ReimIFFTRef;
let m: usize = 1 << log_m;
let table: ReimIFFTTable<f64> = ReimIFFTTable::<f64>::new(m);
let mut values_0: Vec<f64> = vec![0f64; m << 1];
let scale: f64 = 1.0f64 / m as f64;
values_0
.iter_mut()
.enumerate()
.for_each(|(i, x)| *x = (i + 1) as f64 * scale);
let mut values_1: Vec<f64> = vec![0f64; m << 1];
values_1
.iter_mut()
.zip(values_0.iter())
.for_each(|(y, x)| *y = *x);
ReimIFFTAvx::reim_dft_execute(&table, &mut values_0);
ReimIFFTRef::reim_dft_execute(&table, &mut values_1);
let max_diff: f64 = 1.0 / ((1u64 << (53 - log_m - 1)) as f64);
for i in 0..m * 2 {
let diff: f64 = (values_0[i] - values_1[i]).abs();
assert!(
diff <= max_diff,
"{} -> {}-{} = {}",
i,
values_0[i],
values_1[i],
diff
)
}
}
if std::is_x86_feature_detected!("avx2") && std::is_x86_feature_detected!("fma") {
for log_m in 0..16 {
unsafe { internal(log_m) }
}
} else {
eprintln!("skipping: CPU lacks avx2/fma");
}
}


@@ -0,0 +1,72 @@
// ----------------------------------------------------------------------
// DISCLAIMER
//
// This module contains code that has been directly ported from the
// spqlios-arithmetic library
// (https://github.com/tfhe/spqlios-arithmetic), which is licensed
// under the Apache License, Version 2.0.
//
// The porting process from C to Rust was done with minimal changes
// in order to preserve the semantics and performance characteristics
// of the original implementation.
//
// Both Poulpy and spqlios-arithmetic are distributed under the terms
// of the Apache License, Version 2.0. See the LICENSE file for details.
//
// ----------------------------------------------------------------------
#![allow(bad_asm_style)]
mod conversion;
mod fft_avx2_fma;
mod fft_vec_avx2_fma;
mod ifft_avx2_fma;
use std::arch::global_asm;
pub(crate) use conversion::*;
pub(crate) use fft_vec_avx2_fma::*;
use poulpy_hal::reference::fft64::reim::{ReimDFTExecute, ReimFFTTable, ReimIFFTTable};
use rand_distr::num_traits::{Float, FloatConst};
use crate::reim::{fft_avx2_fma::fft_avx2_fma, ifft_avx2_fma::ifft_avx2_fma};
global_asm!(
include_str!("fft16_avx2_fma.s"),
include_str!("ifft16_avx2_fma.s")
);
#[inline(always)]
pub(crate) fn as_arr<const SIZE: usize, R: Float + FloatConst>(x: &[R]) -> &[R; SIZE] {
debug_assert!(x.len() >= SIZE);
unsafe { &*(x.as_ptr() as *const [R; SIZE]) }
}
#[inline(always)]
pub(crate) fn as_arr_mut<const SIZE: usize, R: Float + FloatConst>(x: &mut [R]) -> &mut [R; SIZE] {
debug_assert!(x.len() >= SIZE);
unsafe { &mut *(x.as_mut_ptr() as *mut [R; SIZE]) }
}
pub struct ReimFFTAvx;
impl ReimDFTExecute<ReimFFTTable<f64>, f64> for ReimFFTAvx {
#[inline(always)]
fn reim_dft_execute(table: &ReimFFTTable<f64>, data: &mut [f64]) {
unsafe {
fft_avx2_fma(table.m(), table.omg(), data);
}
}
}
pub struct ReimIFFTAvx;
impl ReimDFTExecute<ReimIFFTTable<f64>, f64> for ReimIFFTAvx {
#[inline(always)]
fn reim_dft_execute(table: &ReimIFFTTable<f64>, data: &mut [f64]) {
unsafe {
ifft_avx2_fma(table.m(), table.omg(), data);
}
}
}
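// Usage sketch (hypothetical helper, not part of the original file): the AVX transforms
// are only sound on CPUs with AVX2+FMA, so a caller that does not go through the
// feature detection in `Module::new` should gate the dispatch itself.
#[allow(dead_code)]
fn reim_fft_avx_checked(table: &ReimFFTTable<f64>, data: &mut [f64]) {
    assert!(
        std::arch::is_x86_feature_detected!("avx2") && std::arch::is_x86_feature_detected!("fma"),
        "FFT64Avx requires avx2 and fma"
    );
    // The reim layout stores m real parts followed by m imaginary parts.
    assert_eq!(data.len(), 2 * table.m());
    ReimFFTAvx::reim_dft_execute(table, data);
}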


@@ -0,0 +1,258 @@
/// # Safety
/// Caller must ensure the CPU supports AVX (e.g., via `is_x86_feature_detected!("avx")`).
#[target_feature(enable = "avx")]
pub fn reim4_extract_1blk_from_reim_avx(m: usize, rows: usize, blk: usize, dst: &mut [f64], src: &[f64]) {
use core::arch::x86_64::{__m256d, _mm256_loadu_pd, _mm256_storeu_pd};
unsafe {
let mut src_ptr: *const __m256d = src.as_ptr().add(blk << 2) as *const __m256d; // src + 4*blk
let mut dst_ptr: *mut __m256d = dst.as_mut_ptr() as *mut __m256d;
let step: usize = m >> 2;
// Each iteration copies 4 doubles; advance src by m doubles each row
for _ in 0..2 * rows {
let v: __m256d = _mm256_loadu_pd(src_ptr as *const f64);
_mm256_storeu_pd(dst_ptr as *mut f64, v);
dst_ptr = dst_ptr.add(1);
src_ptr = src_ptr.add(step);
}
}
}
/// # Safety
/// Caller must ensure the CPU supports AVX2 and FMA (e.g., via `is_x86_feature_detected!("avx2")` and `is_x86_feature_detected!("fma")`).
#[target_feature(enable = "avx2,fma")]
pub fn reim4_save_1blk_to_reim_avx<const OVERWRITE: bool>(m: usize, blk: usize, dst: &mut [f64], src: &[f64]) {
use core::arch::x86_64::{__m256d, _mm256_add_pd, _mm256_loadu_pd, _mm256_storeu_pd};
unsafe {
let off: usize = blk * 4;
let src_ptr: *const f64 = src.as_ptr();
let s0: __m256d = _mm256_loadu_pd(src_ptr);
let s1: __m256d = _mm256_loadu_pd(src_ptr.add(4));
let d0_ptr: *mut f64 = dst.as_mut_ptr().add(off);
let d1_ptr: *mut f64 = d0_ptr.add(m);
if OVERWRITE {
_mm256_storeu_pd(d0_ptr, s0);
_mm256_storeu_pd(d1_ptr, s1);
} else {
let d0: __m256d = _mm256_loadu_pd(d0_ptr);
let d1: __m256d = _mm256_loadu_pd(d1_ptr);
_mm256_storeu_pd(d0_ptr, _mm256_add_pd(d0, s0));
_mm256_storeu_pd(d1_ptr, _mm256_add_pd(d1, s1));
}
}
}
/// # Safety
/// Caller must ensure the CPU supports AVX2 and FMA (e.g., via `is_x86_feature_detected!("avx2")` and `is_x86_feature_detected!("fma")`).
#[target_feature(enable = "avx2,fma")]
pub fn reim4_save_2blk_to_reim_avx<const OVERWRITE: bool>(
m: usize, //
blk: usize, // block index
dst: &mut [f64], //
src: &[f64], // 16 doubles [re1(4), im1(4), re2(4), im2(4)]
) {
use core::arch::x86_64::{__m256d, _mm256_add_pd, _mm256_loadu_pd, _mm256_storeu_pd};
unsafe {
let off: usize = blk * 4;
let src_ptr: *const f64 = src.as_ptr();
let d0_ptr: *mut f64 = dst.as_mut_ptr().add(off);
let d1_ptr: *mut f64 = d0_ptr.add(m);
let d2_ptr: *mut f64 = d1_ptr.add(m);
let d3_ptr: *mut f64 = d2_ptr.add(m);
let s0: __m256d = _mm256_loadu_pd(src_ptr);
let s1: __m256d = _mm256_loadu_pd(src_ptr.add(4));
let s2: __m256d = _mm256_loadu_pd(src_ptr.add(8));
let s3: __m256d = _mm256_loadu_pd(src_ptr.add(12));
if OVERWRITE {
_mm256_storeu_pd(d0_ptr, s0);
_mm256_storeu_pd(d1_ptr, s1);
_mm256_storeu_pd(d2_ptr, s2);
_mm256_storeu_pd(d3_ptr, s3);
} else {
let d0: __m256d = _mm256_loadu_pd(d0_ptr);
let d1: __m256d = _mm256_loadu_pd(d1_ptr);
let d2: __m256d = _mm256_loadu_pd(d2_ptr);
let d3: __m256d = _mm256_loadu_pd(d3_ptr);
_mm256_storeu_pd(d0_ptr, _mm256_add_pd(d0, s0));
_mm256_storeu_pd(d1_ptr, _mm256_add_pd(d1, s1));
_mm256_storeu_pd(d2_ptr, _mm256_add_pd(d2, s2));
_mm256_storeu_pd(d3_ptr, _mm256_add_pd(d3, s3));
}
}
}
/// # Safety
/// Caller must ensure the CPU supports AVX2 and FMA (e.g., via `is_x86_feature_detected!("avx2")` and `is_x86_feature_detected!("fma")`).
#[target_feature(enable = "avx2", enable = "fma")]
pub fn reim4_vec_mat1col_product_avx(nrows: usize, dst: &mut [f64], u: &[f64], v: &[f64]) {
use core::arch::x86_64::{__m256d, _mm256_fmadd_pd, _mm256_loadu_pd, _mm256_setzero_pd, _mm256_storeu_pd};
#[cfg(debug_assertions)]
{
assert!(dst.len() >= 8, "dst must have at least 8 doubles");
assert!(u.len() >= nrows * 8, "u must be at least nrows * 8 doubles");
assert!(v.len() >= nrows * 8, "v must be at least nrows * 8 doubles");
}
unsafe {
use std::arch::x86_64::{_mm256_add_pd, _mm256_sub_pd};
let mut re1: __m256d = _mm256_setzero_pd();
let mut im1: __m256d = _mm256_setzero_pd();
let mut re2: __m256d = _mm256_setzero_pd();
let mut im2: __m256d = _mm256_setzero_pd();
let mut u_ptr: *const f64 = u.as_ptr();
let mut v_ptr: *const f64 = v.as_ptr();
for _ in 0..nrows {
let ur: __m256d = _mm256_loadu_pd(u_ptr);
let ui: __m256d = _mm256_loadu_pd(u_ptr.add(4));
let vr: __m256d = _mm256_loadu_pd(v_ptr);
let vi: __m256d = _mm256_loadu_pd(v_ptr.add(4));
            // re1 = re1 + ur*vr;
            re1 = _mm256_fmadd_pd(ur, vr, re1);
            // im1 = im1 + ur*vi;
            im1 = _mm256_fmadd_pd(ur, vi, im1);
            // re2 = re2 + ui*vi;
            re2 = _mm256_fmadd_pd(ui, vi, re2);
            // im2 = im2 + ui*vr;
            im2 = _mm256_fmadd_pd(ui, vr, im2);
u_ptr = u_ptr.add(8);
v_ptr = v_ptr.add(8);
}
// re1 - re2
_mm256_storeu_pd(dst.as_mut_ptr(), _mm256_sub_pd(re1, re2));
// im1 + im2
_mm256_storeu_pd(dst.as_mut_ptr().add(4), _mm256_add_pd(im1, im2));
}
}
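// Scalar sketch (illustrative, hypothetical helper; the `Reim4Mat1ColProd` trait is the
// real interface) of the accumulation performed by `reim4_vec_mat1col_product_avx`:
// each block of 8 doubles holds 4 real lanes followed by 4 imaginary lanes, and the
// function accumulates a lane-wise complex dot product over `nrows` such blocks.
#[allow(dead_code)]
fn reim4_mat1col_product_scalar_sketch(nrows: usize, dst: &mut [f64], u: &[f64], v: &[f64]) {
    dst[..8].fill(0.0);
    for row in 0..nrows {
        for lane in 0..4 {
            let (ur, ui) = (u[8 * row + lane], u[8 * row + 4 + lane]);
            let (vr, vi) = (v[8 * row + lane], v[8 * row + 4 + lane]);
            // Real part: ur*vr - ui*vi, imaginary part: ur*vi + ui*vr.
            dst[lane] += ur * vr - ui * vi;
            dst[4 + lane] += ur * vi + ui * vr;
        }
    }
}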
/// # Safety
/// Caller must ensure the CPU supports AVX2 and FMA (e.g., via `is_x86_feature_detected!("avx2")` and `is_x86_feature_detected!("fma")`).
#[target_feature(enable = "avx2", enable = "fma")]
pub fn reim4_vec_mat2cols_product_avx(nrows: usize, dst: &mut [f64], u: &[f64], v: &[f64]) {
use core::arch::x86_64::{__m256d, _mm256_fmadd_pd, _mm256_fmsub_pd, _mm256_loadu_pd, _mm256_setzero_pd, _mm256_storeu_pd};
#[cfg(debug_assertions)]
{
assert!(
dst.len() >= 8,
"dst must be at least 8 doubles but is {}",
dst.len()
);
assert!(
u.len() >= nrows * 8,
"u must be at least nrows={} * 8 doubles but is {}",
nrows,
u.len()
);
assert!(
v.len() >= nrows * 16,
"v must be at least nrows={} * 16 doubles but is {}",
nrows,
v.len()
);
}
unsafe {
let mut re1: __m256d = _mm256_setzero_pd();
let mut im1: __m256d = _mm256_setzero_pd();
let mut re2: __m256d = _mm256_setzero_pd();
let mut im2: __m256d = _mm256_setzero_pd();
let mut u_ptr: *const f64 = u.as_ptr();
let mut v_ptr: *const f64 = v.as_ptr();
for _ in 0..nrows {
let ur: __m256d = _mm256_loadu_pd(u_ptr);
let ui: __m256d = _mm256_loadu_pd(u_ptr.add(4));
let ar: __m256d = _mm256_loadu_pd(v_ptr);
let ai: __m256d = _mm256_loadu_pd(v_ptr.add(4));
let br: __m256d = _mm256_loadu_pd(v_ptr.add(8));
let bi: __m256d = _mm256_loadu_pd(v_ptr.add(12));
            // re1 = ui*ai - re1; re2 = ui*bi - re2;
            re1 = _mm256_fmsub_pd(ui, ai, re1);
            re2 = _mm256_fmsub_pd(ui, bi, re2);
            // im1 = im1 + ur*ai; im2 = im2 + ur*bi;
            im1 = _mm256_fmadd_pd(ur, ai, im1);
            im2 = _mm256_fmadd_pd(ur, bi, im2);
            // re1 = ur*ar - re1; re2 = ur*br - re2; (net per row: re += ur*ar - ui*ai)
            re1 = _mm256_fmsub_pd(ur, ar, re1);
            re2 = _mm256_fmsub_pd(ur, br, re2);
            // im1 = im1 + ui*ar; im2 = im2 + ui*br;
            im1 = _mm256_fmadd_pd(ui, ar, im1);
            im2 = _mm256_fmadd_pd(ui, br, im2);
u_ptr = u_ptr.add(8);
v_ptr = v_ptr.add(16);
}
_mm256_storeu_pd(dst.as_mut_ptr(), re1);
_mm256_storeu_pd(dst.as_mut_ptr().add(4), im1);
_mm256_storeu_pd(dst.as_mut_ptr().add(8), re2);
_mm256_storeu_pd(dst.as_mut_ptr().add(12), im2);
}
}
/// # Safety
/// Caller must ensure the CPU supports AVX2 and FMA (e.g., via `is_x86_feature_detected!("avx2")` and `is_x86_feature_detected!("fma")`).
#[target_feature(enable = "avx2", enable = "fma")]
pub fn reim4_vec_mat2cols_2ndcol_product_avx(nrows: usize, dst: &mut [f64], u: &[f64], v: &[f64]) {
use core::arch::x86_64::{__m256d, _mm256_fmadd_pd, _mm256_fmsub_pd, _mm256_loadu_pd, _mm256_setzero_pd, _mm256_storeu_pd};
#[cfg(debug_assertions)]
{
assert_eq!(dst.len(), 16, "dst must have 16 doubles");
assert!(u.len() >= nrows * 8, "u must be at least nrows * 8 doubles");
assert!(
v.len() >= nrows * 16,
"v must be at least nrows * 16 doubles"
);
}
unsafe {
let mut re1: __m256d = _mm256_setzero_pd();
let mut im1: __m256d = _mm256_setzero_pd();
let mut u_ptr: *const f64 = u.as_ptr();
let mut v_ptr: *const f64 = v.as_ptr().add(8); // Offset to 2nd column
for _ in 0..nrows {
let ur: __m256d = _mm256_loadu_pd(u_ptr);
let ui: __m256d = _mm256_loadu_pd(u_ptr.add(4));
let ar: __m256d = _mm256_loadu_pd(v_ptr);
let ai: __m256d = _mm256_loadu_pd(v_ptr.add(4));
            // re1 = ui*ai - re1
            re1 = _mm256_fmsub_pd(ui, ai, re1);
            // im1 = im1 + ur*ai
            im1 = _mm256_fmadd_pd(ur, ai, im1);
            // re1 = ur*ar - re1 (net per row: re1 += ur*ar - ui*ai)
            re1 = _mm256_fmsub_pd(ur, ar, re1);
            // im1 = im1 + ui*ar
            im1 = _mm256_fmadd_pd(ui, ar, im1);
u_ptr = u_ptr.add(8);
v_ptr = v_ptr.add(16);
}
_mm256_storeu_pd(dst.as_mut_ptr(), re1);
_mm256_storeu_pd(dst.as_mut_ptr().add(4), im1);
}
}


@@ -0,0 +1,3 @@
mod arithmetic_avx;
pub(crate) use arithmetic_avx::*;


@@ -0,0 +1,81 @@
use std::marker::PhantomData;
use poulpy_hal::{
DEFAULTALIGN, alloc_aligned,
api::ScratchFromBytes,
layouts::{Backend, Scratch, ScratchOwned},
oep::{ScratchAvailableImpl, ScratchFromBytesImpl, ScratchOwnedAllocImpl, ScratchOwnedBorrowImpl, TakeSliceImpl},
};
use crate::FFT64Avx;
unsafe impl<B: Backend> ScratchOwnedAllocImpl<B> for FFT64Avx {
fn scratch_owned_alloc_impl(size: usize) -> ScratchOwned<B> {
let data: Vec<u8> = alloc_aligned(size);
ScratchOwned {
data,
_phantom: PhantomData,
}
}
}
unsafe impl<B: Backend> ScratchOwnedBorrowImpl<B> for FFT64Avx
where
B: ScratchFromBytesImpl<B>,
{
fn scratch_owned_borrow_impl(scratch: &mut ScratchOwned<B>) -> &mut Scratch<B> {
Scratch::from_bytes(&mut scratch.data)
}
}
unsafe impl<B: Backend> ScratchFromBytesImpl<B> for FFT64Avx {
fn scratch_from_bytes_impl(data: &mut [u8]) -> &mut Scratch<B> {
unsafe { &mut *(data as *mut [u8] as *mut Scratch<B>) }
}
}
unsafe impl<B: Backend> ScratchAvailableImpl<B> for FFT64Avx {
fn scratch_available_impl(scratch: &Scratch<B>) -> usize {
let ptr: *const u8 = scratch.data.as_ptr();
let self_len: usize = scratch.data.len();
let aligned_offset: usize = ptr.align_offset(DEFAULTALIGN);
self_len.saturating_sub(aligned_offset)
}
}
unsafe impl<B: Backend> TakeSliceImpl<B> for FFT64Avx
where
B: ScratchFromBytesImpl<B>,
{
fn take_slice_impl<T>(scratch: &mut Scratch<B>, len: usize) -> (&mut [T], &mut Scratch<B>) {
let (take_slice, rem_slice) = take_slice_aligned(&mut scratch.data, len * std::mem::size_of::<T>());
unsafe {
(
&mut *(std::ptr::slice_from_raw_parts_mut(take_slice.as_mut_ptr() as *mut T, len)),
Scratch::from_bytes(rem_slice),
)
}
}
}
fn take_slice_aligned(data: &mut [u8], take_len: usize) -> (&mut [u8], &mut [u8]) {
let ptr: *mut u8 = data.as_mut_ptr();
let self_len: usize = data.len();
let aligned_offset: usize = ptr.align_offset(DEFAULTALIGN);
let aligned_len: usize = self_len.saturating_sub(aligned_offset);
if let Some(rem_len) = aligned_len.checked_sub(take_len) {
unsafe {
let rem_ptr: *mut u8 = ptr.add(aligned_offset).add(take_len);
let rem_slice: &mut [u8] = &mut *std::ptr::slice_from_raw_parts_mut(rem_ptr, rem_len);
let take_slice: &mut [u8] = &mut *std::ptr::slice_from_raw_parts_mut(ptr.add(aligned_offset), take_len);
(take_slice, rem_slice)
}
} else {
panic!("Attempted to take {take_len} from scratch with {aligned_len} aligned bytes left");
}
}
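// Usage sketch (illustrative, hypothetical helper; real callers reach these through the
// poulpy_hal `api` traits rather than the `*_impl` entry points): allocate an aligned
// scratch buffer, borrow it, and carve an i64 carry slice off the front.
#[allow(dead_code)]
fn scratch_usage_sketch() {
    let mut owned: ScratchOwned<FFT64Avx> = FFT64Avx::scratch_owned_alloc_impl(1 << 12);
    let scratch: &mut Scratch<FFT64Avx> = FFT64Avx::scratch_owned_borrow_impl(&mut owned);
    let (carry, _rest) = FFT64Avx::take_slice_impl::<i64>(scratch, 64);
    carry.fill(0);
}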

66
poulpy-cpu-avx/src/svp.rs Normal file

@@ -0,0 +1,66 @@
use poulpy_hal::{
layouts::{Backend, Module, ScalarZnxToRef, SvpPPolOwned, SvpPPolToMut, SvpPPolToRef, VecZnxDftToMut, VecZnxDftToRef},
oep::{
SvpApplyDftToDftImpl, SvpApplyDftToDftInplaceImpl, SvpPPolAllocBytesImpl, SvpPPolAllocImpl, SvpPPolFromBytesImpl,
SvpPrepareImpl,
},
reference::fft64::svp::{svp_apply_dft_to_dft, svp_apply_dft_to_dft_inplace, svp_prepare},
};
use crate::{FFT64Avx, module::FFT64ModuleHandle};
unsafe impl SvpPPolFromBytesImpl<Self> for FFT64Avx {
fn svp_ppol_from_bytes_impl(n: usize, cols: usize, bytes: Vec<u8>) -> SvpPPolOwned<Self> {
SvpPPolOwned::from_bytes(n, cols, bytes)
}
}
unsafe impl SvpPPolAllocImpl<Self> for FFT64Avx {
fn svp_ppol_alloc_impl(n: usize, cols: usize) -> SvpPPolOwned<Self> {
SvpPPolOwned::alloc(n, cols)
}
}
unsafe impl SvpPPolAllocBytesImpl<Self> for FFT64Avx {
fn svp_ppol_bytes_of_impl(n: usize, cols: usize) -> usize {
Self::layout_prep_word_count() * n * cols * size_of::<f64>()
}
}
unsafe impl SvpPrepareImpl<Self> for FFT64Avx {
fn svp_prepare_impl<R, A>(module: &Module<Self>, res: &mut R, res_col: usize, a: &A, a_col: usize)
where
R: SvpPPolToMut<Self>,
A: ScalarZnxToRef,
{
svp_prepare(module.get_fft_table(), res, res_col, a, a_col);
}
}
unsafe impl SvpApplyDftToDftImpl<Self> for FFT64Avx {
fn svp_apply_dft_to_dft_impl<R, A, B>(
_module: &Module<Self>,
res: &mut R,
res_col: usize,
a: &A,
a_col: usize,
b: &B,
b_col: usize,
) where
R: VecZnxDftToMut<Self>,
A: SvpPPolToRef<Self>,
B: VecZnxDftToRef<Self>,
{
svp_apply_dft_to_dft(res, res_col, a, a_col, b, b_col);
}
}
unsafe impl SvpApplyDftToDftInplaceImpl for FFT64Avx {
fn svp_apply_dft_to_dft_inplace_impl<R, A>(_module: &Module<Self>, res: &mut R, res_col: usize, a: &A, a_col: usize)
where
R: VecZnxDftToMut<Self>,
A: SvpPPolToRef<Self>,
{
svp_apply_dft_to_dft_inplace(res, res_col, a, a_col);
}
}

125
poulpy-cpu-avx/src/tests.rs Normal file

@@ -0,0 +1,125 @@
use poulpy_hal::{api::ModuleNew, layouts::Module, test_suite::convolution::test_bivariate_tensoring};
use crate::FFT64Avx;
#[cfg(test)]
mod poulpy_cpu_avx {
use poulpy_hal::{backend_test_suite, cross_backend_test_suite};
cross_backend_test_suite! {
mod vec_znx,
backend_ref = poulpy_cpu_ref::FFT64Ref,
backend_test = crate::FFT64Avx,
size = 1<<8,
base2k = 12,
tests = {
test_vec_znx_add => poulpy_hal::test_suite::vec_znx::test_vec_znx_add,
test_vec_znx_add_inplace => poulpy_hal::test_suite::vec_znx::test_vec_znx_add_inplace,
test_vec_znx_add_scalar => poulpy_hal::test_suite::vec_znx::test_vec_znx_add_scalar,
test_vec_znx_add_scalar_inplace => poulpy_hal::test_suite::vec_znx::test_vec_znx_add_scalar_inplace,
test_vec_znx_sub => poulpy_hal::test_suite::vec_znx::test_vec_znx_sub,
test_vec_znx_sub_inplace => poulpy_hal::test_suite::vec_znx::test_vec_znx_sub_inplace,
test_vec_znx_sub_negate_inplace => poulpy_hal::test_suite::vec_znx::test_vec_znx_sub_negate_inplace,
test_vec_znx_sub_scalar => poulpy_hal::test_suite::vec_znx::test_vec_znx_sub_scalar,
test_vec_znx_sub_scalar_inplace => poulpy_hal::test_suite::vec_znx::test_vec_znx_sub_scalar_inplace,
test_vec_znx_rsh => poulpy_hal::test_suite::vec_znx::test_vec_znx_rsh,
test_vec_znx_rsh_inplace => poulpy_hal::test_suite::vec_znx::test_vec_znx_rsh_inplace,
test_vec_znx_lsh => poulpy_hal::test_suite::vec_znx::test_vec_znx_lsh,
test_vec_znx_lsh_inplace => poulpy_hal::test_suite::vec_znx::test_vec_znx_lsh_inplace,
test_vec_znx_negate => poulpy_hal::test_suite::vec_znx::test_vec_znx_negate,
test_vec_znx_negate_inplace => poulpy_hal::test_suite::vec_znx::test_vec_znx_negate_inplace,
test_vec_znx_rotate => poulpy_hal::test_suite::vec_znx::test_vec_znx_rotate,
test_vec_znx_rotate_inplace => poulpy_hal::test_suite::vec_znx::test_vec_znx_rotate_inplace,
test_vec_znx_automorphism => poulpy_hal::test_suite::vec_znx::test_vec_znx_automorphism,
test_vec_znx_automorphism_inplace => poulpy_hal::test_suite::vec_znx::test_vec_znx_automorphism_inplace,
test_vec_znx_mul_xp_minus_one => poulpy_hal::test_suite::vec_znx::test_vec_znx_mul_xp_minus_one,
test_vec_znx_mul_xp_minus_one_inplace => poulpy_hal::test_suite::vec_znx::test_vec_znx_mul_xp_minus_one_inplace,
test_vec_znx_normalize => poulpy_hal::test_suite::vec_znx::test_vec_znx_normalize,
test_vec_znx_normalize_inplace => poulpy_hal::test_suite::vec_znx::test_vec_znx_normalize_inplace,
test_vec_znx_switch_ring => poulpy_hal::test_suite::vec_znx::test_vec_znx_switch_ring,
test_vec_znx_split_ring => poulpy_hal::test_suite::vec_znx::test_vec_znx_split_ring,
test_vec_znx_copy => poulpy_hal::test_suite::vec_znx::test_vec_znx_copy,
}
}
cross_backend_test_suite! {
mod svp,
backend_ref = poulpy_cpu_ref::FFT64Ref,
backend_test = crate::FFT64Avx,
size = 1<<8,
base2k = 12,
tests = {
test_svp_apply_dft_to_dft => poulpy_hal::test_suite::svp::test_svp_apply_dft_to_dft,
test_svp_apply_dft_to_dft_inplace => poulpy_hal::test_suite::svp::test_svp_apply_dft_to_dft_inplace,
}
}
cross_backend_test_suite! {
mod vec_znx_big,
backend_ref = poulpy_cpu_ref::FFT64Ref,
backend_test = crate::FFT64Avx,
size = 1<<8,
base2k = 12,
tests = {
test_vec_znx_big_add => poulpy_hal::test_suite::vec_znx_big::test_vec_znx_big_add,
test_vec_znx_big_add_inplace => poulpy_hal::test_suite::vec_znx_big::test_vec_znx_big_add_inplace,
test_vec_znx_big_add_small => poulpy_hal::test_suite::vec_znx_big::test_vec_znx_big_add_small,
test_vec_znx_big_add_small_inplace => poulpy_hal::test_suite::vec_znx_big::test_vec_znx_big_add_small_inplace,
test_vec_znx_big_sub => poulpy_hal::test_suite::vec_znx_big::test_vec_znx_big_sub,
test_vec_znx_big_sub_inplace => poulpy_hal::test_suite::vec_znx_big::test_vec_znx_big_sub_inplace,
test_vec_znx_big_automorphism => poulpy_hal::test_suite::vec_znx_big::test_vec_znx_big_automorphism,
test_vec_znx_big_automorphism_inplace => poulpy_hal::test_suite::vec_znx_big::test_vec_znx_big_automorphism_inplace,
test_vec_znx_big_negate => poulpy_hal::test_suite::vec_znx_big::test_vec_znx_big_negate,
test_vec_znx_big_negate_inplace => poulpy_hal::test_suite::vec_znx_big::test_vec_znx_big_negate_inplace,
test_vec_znx_big_normalize => poulpy_hal::test_suite::vec_znx_big::test_vec_znx_big_normalize,
test_vec_znx_big_sub_negate_inplace => poulpy_hal::test_suite::vec_znx_big::test_vec_znx_big_sub_negate_inplace,
test_vec_znx_big_sub_small_a => poulpy_hal::test_suite::vec_znx_big::test_vec_znx_big_sub_small_a,
test_vec_znx_big_sub_small_a_inplace => poulpy_hal::test_suite::vec_znx_big::test_vec_znx_big_sub_small_a_inplace,
test_vec_znx_big_sub_small_b => poulpy_hal::test_suite::vec_znx_big::test_vec_znx_big_sub_small_b,
test_vec_znx_big_sub_small_b_inplace => poulpy_hal::test_suite::vec_znx_big::test_vec_znx_big_sub_small_b_inplace,
}
}
cross_backend_test_suite! {
mod vec_znx_dft,
backend_ref = poulpy_cpu_ref::FFT64Ref,
backend_test = crate::FFT64Avx,
size = 1<<8,
base2k = 12,
tests = {
test_vec_znx_dft_add => poulpy_hal::test_suite::vec_znx_dft::test_vec_znx_dft_add,
test_vec_znx_dft_add_inplace => poulpy_hal::test_suite::vec_znx_dft::test_vec_znx_dft_add_inplace,
test_vec_znx_dft_sub => poulpy_hal::test_suite::vec_znx_dft::test_vec_znx_dft_sub,
test_vec_znx_dft_sub_inplace => poulpy_hal::test_suite::vec_znx_dft::test_vec_znx_dft_sub_inplace,
test_vec_znx_dft_sub_negate_inplace => poulpy_hal::test_suite::vec_znx_dft::test_vec_znx_dft_sub_negate_inplace,
test_vec_znx_idft_apply => poulpy_hal::test_suite::vec_znx_dft::test_vec_znx_idft_apply,
test_vec_znx_idft_apply_consume => poulpy_hal::test_suite::vec_znx_dft::test_vec_znx_idft_apply_consume,
test_vec_znx_idft_apply_tmpa => poulpy_hal::test_suite::vec_znx_dft::test_vec_znx_idft_apply_tmpa,
}
}
cross_backend_test_suite! {
mod vmp,
backend_ref = poulpy_cpu_ref::FFT64Ref,
backend_test = crate::FFT64Avx,
size = 1<<8,
base2k = 12,
tests = {
test_vmp_apply_dft_to_dft => poulpy_hal::test_suite::vmp::test_vmp_apply_dft_to_dft,
test_vmp_apply_dft_to_dft_add => poulpy_hal::test_suite::vmp::test_vmp_apply_dft_to_dft_add,
}
}
backend_test_suite! {
mod sampling,
backend = crate::FFT64Avx,
size = 1<<12,
tests = {
test_vec_znx_fill_uniform => poulpy_hal::test_suite::vec_znx::test_vec_znx_fill_uniform,
test_vec_znx_fill_normal => poulpy_hal::test_suite::vec_znx::test_vec_znx_fill_normal,
test_vec_znx_add_normal => poulpy_hal::test_suite::vec_znx::test_vec_znx_add_normal,
}
}
}
#[test]
fn test_convolution_fft64_avx() {
let module: Module<FFT64Avx> = Module::<FFT64Avx>::new(64);
test_bivariate_tensoring(&module);
}


@@ -0,0 +1,549 @@
use poulpy_hal::{
api::{
TakeSlice, VecZnxAutomorphismInplaceTmpBytes, VecZnxLshTmpBytes, VecZnxMergeRingsTmpBytes,
VecZnxMulXpMinusOneInplaceTmpBytes, VecZnxNormalizeTmpBytes, VecZnxRotateInplaceTmpBytes, VecZnxRshTmpBytes,
VecZnxSplitRingTmpBytes,
},
layouts::{Module, ScalarZnxToRef, Scratch, VecZnxToMut, VecZnxToRef},
oep::{
TakeSliceImpl, VecZnxAddImpl, VecZnxAddInplaceImpl, VecZnxAddNormalImpl, VecZnxAddScalarImpl, VecZnxAddScalarInplaceImpl,
VecZnxAutomorphismImpl, VecZnxAutomorphismInplaceImpl, VecZnxAutomorphismInplaceTmpBytesImpl, VecZnxCopyImpl,
VecZnxFillNormalImpl, VecZnxFillUniformImpl, VecZnxLshImpl, VecZnxLshInplaceImpl, VecZnxLshTmpBytesImpl,
VecZnxMergeRingsImpl, VecZnxMergeRingsTmpBytesImpl, VecZnxMulXpMinusOneImpl, VecZnxMulXpMinusOneInplaceImpl,
VecZnxMulXpMinusOneInplaceTmpBytesImpl, VecZnxNegateImpl, VecZnxNegateInplaceImpl, VecZnxNormalizeImpl,
VecZnxNormalizeInplaceImpl, VecZnxNormalizeTmpBytesImpl, VecZnxRotateImpl, VecZnxRotateInplaceImpl,
VecZnxRotateInplaceTmpBytesImpl, VecZnxRshImpl, VecZnxRshInplaceImpl, VecZnxRshTmpBytesImpl, VecZnxSplitRingImpl,
VecZnxSplitRingTmpBytesImpl, VecZnxSubImpl, VecZnxSubInplaceImpl, VecZnxSubNegateInplaceImpl, VecZnxSubScalarImpl,
VecZnxSubScalarInplaceImpl, VecZnxSwitchRingImpl, VecZnxZeroImpl,
},
reference::vec_znx::{
vec_znx_add, vec_znx_add_inplace, vec_znx_add_normal_ref, vec_znx_add_scalar, vec_znx_add_scalar_inplace,
vec_znx_automorphism, vec_znx_automorphism_inplace, vec_znx_automorphism_inplace_tmp_bytes, vec_znx_copy,
vec_znx_fill_normal_ref, vec_znx_fill_uniform_ref, vec_znx_lsh, vec_znx_lsh_inplace, vec_znx_lsh_tmp_bytes,
vec_znx_merge_rings, vec_znx_merge_rings_tmp_bytes, vec_znx_mul_xp_minus_one, vec_znx_mul_xp_minus_one_inplace,
vec_znx_mul_xp_minus_one_inplace_tmp_bytes, vec_znx_negate, vec_znx_negate_inplace, vec_znx_normalize,
vec_znx_normalize_inplace, vec_znx_normalize_tmp_bytes, vec_znx_rotate, vec_znx_rotate_inplace,
vec_znx_rotate_inplace_tmp_bytes, vec_znx_rsh, vec_znx_rsh_inplace, vec_znx_rsh_tmp_bytes, vec_znx_split_ring,
vec_znx_split_ring_tmp_bytes, vec_znx_sub, vec_znx_sub_inplace, vec_znx_sub_negate_inplace, vec_znx_sub_scalar,
vec_znx_sub_scalar_inplace, vec_znx_switch_ring, vec_znx_zero,
},
source::Source,
};
use crate::FFT64Avx;
unsafe impl VecZnxZeroImpl<Self> for FFT64Avx {
fn vec_znx_zero_impl<R>(_module: &Module<Self>, res: &mut R, res_col: usize)
where
R: VecZnxToMut,
{
vec_znx_zero::<_, FFT64Avx>(res, res_col);
}
}
unsafe impl VecZnxNormalizeTmpBytesImpl<Self> for FFT64Avx {
fn vec_znx_normalize_tmp_bytes_impl(module: &Module<Self>) -> usize {
vec_znx_normalize_tmp_bytes(module.n())
}
}
unsafe impl VecZnxNormalizeImpl<Self> for FFT64Avx
where
Self: TakeSliceImpl<Self> + VecZnxNormalizeTmpBytesImpl<Self>,
{
fn vec_znx_normalize_impl<R, A>(
module: &Module<Self>,
res_base2k: usize,
res: &mut R,
res_col: usize,
a_base2k: usize,
a: &A,
a_col: usize,
scratch: &mut Scratch<Self>,
) where
R: VecZnxToMut,
A: VecZnxToRef,
{
let (carry, _) = scratch.take_slice(module.vec_znx_normalize_tmp_bytes() / size_of::<i64>());
vec_znx_normalize::<R, A, Self>(res_base2k, res, res_col, a_base2k, a, a_col, carry);
}
}
unsafe impl VecZnxNormalizeInplaceImpl<Self> for FFT64Avx
where
Self: TakeSliceImpl<Self> + VecZnxNormalizeTmpBytesImpl<Self>,
{
fn vec_znx_normalize_inplace_impl<R>(
module: &Module<Self>,
base2k: usize,
res: &mut R,
res_col: usize,
scratch: &mut Scratch<Self>,
) where
R: VecZnxToMut,
{
let (carry, _) = scratch.take_slice(module.vec_znx_normalize_tmp_bytes() / size_of::<i64>());
vec_znx_normalize_inplace::<R, Self>(base2k, res, res_col, carry);
}
}
unsafe impl VecZnxAddImpl<Self> for FFT64Avx {
fn vec_znx_add_impl<R, A, B>(_module: &Module<Self>, res: &mut R, res_col: usize, a: &A, a_col: usize, b: &B, b_col: usize)
where
R: VecZnxToMut,
A: VecZnxToRef,
B: VecZnxToRef,
{
vec_znx_add::<R, A, B, Self>(res, res_col, a, a_col, b, b_col);
}
}
unsafe impl VecZnxAddInplaceImpl<Self> for FFT64Avx {
fn vec_znx_add_inplace_impl<R, A>(_module: &Module<Self>, res: &mut R, res_col: usize, a: &A, a_col: usize)
where
R: VecZnxToMut,
A: VecZnxToRef,
{
vec_znx_add_inplace::<R, A, Self>(res, res_col, a, a_col);
}
}
unsafe impl VecZnxAddScalarInplaceImpl<Self> for FFT64Avx {
fn vec_znx_add_scalar_inplace_impl<R, A>(
_module: &Module<Self>,
res: &mut R,
res_col: usize,
res_limb: usize,
a: &A,
a_col: usize,
) where
R: VecZnxToMut,
A: ScalarZnxToRef,
{
vec_znx_add_scalar_inplace::<R, A, Self>(res, res_col, res_limb, a, a_col);
}
}
unsafe impl VecZnxAddScalarImpl<Self> for FFT64Avx {
fn vec_znx_add_scalar_impl<R, A, B>(
_module: &Module<Self>,
res: &mut R,
res_col: usize,
a: &A,
a_col: usize,
b: &B,
b_col: usize,
b_limb: usize,
) where
R: VecZnxToMut,
A: ScalarZnxToRef,
B: VecZnxToRef,
{
vec_znx_add_scalar::<R, A, B, Self>(res, res_col, a, a_col, b, b_col, b_limb);
}
}
unsafe impl VecZnxSubImpl<Self> for FFT64Avx {
fn vec_znx_sub_impl<R, A, B>(_module: &Module<Self>, res: &mut R, res_col: usize, a: &A, a_col: usize, b: &B, b_col: usize)
where
R: VecZnxToMut,
A: VecZnxToRef,
B: VecZnxToRef,
{
vec_znx_sub::<R, A, B, Self>(res, res_col, a, a_col, b, b_col);
}
}
unsafe impl VecZnxSubInplaceImpl<Self> for FFT64Avx {
fn vec_znx_sub_inplace_impl<R, A>(_module: &Module<Self>, res: &mut R, res_col: usize, a: &A, a_col: usize)
where
R: VecZnxToMut,
A: VecZnxToRef,
{
vec_znx_sub_inplace::<R, A, Self>(res, res_col, a, a_col);
}
}
unsafe impl VecZnxSubNegateInplaceImpl<Self> for FFT64Avx {
fn vec_znx_sub_negate_inplace_impl<R, A>(_module: &Module<Self>, res: &mut R, res_col: usize, a: &A, a_col: usize)
where
R: VecZnxToMut,
A: VecZnxToRef,
{
vec_znx_sub_negate_inplace::<R, A, Self>(res, res_col, a, a_col);
}
}
unsafe impl VecZnxSubScalarImpl<Self> for FFT64Avx {
fn vec_znx_sub_scalar_impl<R, A, B>(
_module: &Module<Self>,
res: &mut R,
res_col: usize,
a: &A,
a_col: usize,
b: &B,
b_col: usize,
b_limb: usize,
) where
R: VecZnxToMut,
A: ScalarZnxToRef,
B: VecZnxToRef,
{
vec_znx_sub_scalar::<R, A, B, Self>(res, res_col, a, a_col, b, b_col, b_limb);
}
}
unsafe impl VecZnxSubScalarInplaceImpl<Self> for FFT64Avx {
fn vec_znx_sub_scalar_inplace_impl<R, A>(
_module: &Module<Self>,
res: &mut R,
res_col: usize,
res_limb: usize,
a: &A,
a_col: usize,
) where
R: VecZnxToMut,
A: ScalarZnxToRef,
{
vec_znx_sub_scalar_inplace::<R, A, Self>(res, res_col, res_limb, a, a_col);
}
}
unsafe impl VecZnxNegateImpl<Self> for FFT64Avx {
fn vec_znx_negate_impl<R, A>(_module: &Module<Self>, res: &mut R, res_col: usize, a: &A, a_col: usize)
where
R: VecZnxToMut,
A: VecZnxToRef,
{
vec_znx_negate::<R, A, Self>(res, res_col, a, a_col);
}
}
unsafe impl VecZnxNegateInplaceImpl<Self> for FFT64Avx {
fn vec_znx_negate_inplace_impl<R>(_module: &Module<Self>, res: &mut R, res_col: usize)
where
R: VecZnxToMut,
{
vec_znx_negate_inplace::<R, Self>(res, res_col);
}
}
unsafe impl VecZnxLshTmpBytesImpl<Self> for FFT64Avx {
fn vec_znx_lsh_tmp_bytes_impl(module: &Module<Self>) -> usize {
vec_znx_lsh_tmp_bytes(module.n())
}
}
unsafe impl VecZnxRshTmpBytesImpl<Self> for FFT64Avx {
fn vec_znx_rsh_tmp_bytes_impl(module: &Module<Self>) -> usize {
vec_znx_rsh_tmp_bytes(module.n())
}
}
unsafe impl VecZnxLshImpl<Self> for FFT64Avx
where
Module<Self>: VecZnxNormalizeTmpBytes,
Scratch<Self>: TakeSlice,
{
fn vec_znx_lsh_impl<R, A>(
module: &Module<Self>,
base2k: usize,
k: usize,
res: &mut R,
res_col: usize,
a: &A,
a_col: usize,
scratch: &mut Scratch<Self>,
) where
R: VecZnxToMut,
A: VecZnxToRef,
{
let (carry, _) = scratch.take_slice(module.vec_znx_lsh_tmp_bytes() / size_of::<i64>());
vec_znx_lsh::<_, _, Self>(base2k, k, res, res_col, a, a_col, carry);
}
}
unsafe impl VecZnxLshInplaceImpl<Self> for FFT64Avx
where
Module<Self>: VecZnxNormalizeTmpBytes,
Scratch<Self>: TakeSlice,
{
fn vec_znx_lsh_inplace_impl<A>(
module: &Module<Self>,
base2k: usize,
k: usize,
a: &mut A,
a_col: usize,
scratch: &mut Scratch<Self>,
) where
A: VecZnxToMut,
{
let (carry, _) = scratch.take_slice(module.vec_znx_lsh_tmp_bytes() / size_of::<i64>());
vec_znx_lsh_inplace::<_, Self>(base2k, k, a, a_col, carry);
}
}
unsafe impl VecZnxRshImpl<Self> for FFT64Avx
where
Module<Self>: VecZnxNormalizeTmpBytes,
Scratch<Self>: TakeSlice,
{
fn vec_znx_rsh_impl<R, A>(
module: &Module<Self>,
base2k: usize,
k: usize,
res: &mut R,
res_col: usize,
a: &A,
a_col: usize,
scratch: &mut Scratch<Self>,
) where
R: VecZnxToMut,
A: VecZnxToRef,
{
let (carry, _) = scratch.take_slice(module.vec_znx_rsh_tmp_bytes() / size_of::<i64>());
vec_znx_rsh::<_, _, Self>(base2k, k, res, res_col, a, a_col, carry);
}
}
unsafe impl VecZnxRshInplaceImpl<Self> for FFT64Avx
where
Module<Self>: VecZnxNormalizeTmpBytes,
Scratch<Self>: TakeSlice,
{
fn vec_znx_rsh_inplace_impl<A>(
module: &Module<Self>,
base2k: usize,
k: usize,
a: &mut A,
a_col: usize,
scratch: &mut Scratch<Self>,
) where
A: VecZnxToMut,
{
let (carry, _) = scratch.take_slice(module.vec_znx_rsh_tmp_bytes() / size_of::<i64>());
vec_znx_rsh_inplace::<_, Self>(base2k, k, a, a_col, carry);
}
}
unsafe impl VecZnxRotateImpl<Self> for FFT64Avx {
fn vec_znx_rotate_impl<R, A>(_module: &Module<Self>, p: i64, res: &mut R, res_col: usize, a: &A, a_col: usize)
where
R: VecZnxToMut,
A: VecZnxToRef,
{
vec_znx_rotate::<R, A, Self>(p, res, res_col, a, a_col);
}
}
unsafe impl VecZnxRotateInplaceTmpBytesImpl<Self> for FFT64Avx
where
Scratch<Self>: TakeSlice,
{
fn vec_znx_rotate_inplace_tmp_bytes_impl(module: &Module<Self>) -> usize {
vec_znx_rotate_inplace_tmp_bytes(module.n())
}
}
unsafe impl VecZnxRotateInplaceImpl<Self> for FFT64Avx
where
Scratch<Self>: TakeSlice,
Self: VecZnxRotateInplaceTmpBytesImpl<Self>,
{
fn vec_znx_rotate_inplace_impl<R>(module: &Module<Self>, p: i64, res: &mut R, res_col: usize, scratch: &mut Scratch<Self>)
where
R: VecZnxToMut,
{
let (tmp, _) = scratch.take_slice(module.vec_znx_rotate_inplace_tmp_bytes() / size_of::<i64>());
vec_znx_rotate_inplace::<R, Self>(p, res, res_col, tmp);
}
}
unsafe impl VecZnxAutomorphismImpl<Self> for FFT64Avx {
fn vec_znx_automorphism_impl<R, A>(_module: &Module<Self>, p: i64, res: &mut R, res_col: usize, a: &A, a_col: usize)
where
R: VecZnxToMut,
A: VecZnxToRef,
{
vec_znx_automorphism::<R, A, Self>(p, res, res_col, a, a_col);
}
}
unsafe impl VecZnxAutomorphismInplaceTmpBytesImpl<Self> for FFT64Avx {
fn vec_znx_automorphism_inplace_tmp_bytes_impl(module: &Module<Self>) -> usize {
vec_znx_automorphism_inplace_tmp_bytes(module.n())
}
}
unsafe impl VecZnxAutomorphismInplaceImpl<Self> for FFT64Avx
where
Scratch<Self>: TakeSlice,
Self: VecZnxAutomorphismInplaceTmpBytesImpl<Self>,
{
fn vec_znx_automorphism_inplace_impl<R>(
module: &Module<Self>,
p: i64,
res: &mut R,
res_col: usize,
scratch: &mut Scratch<Self>,
) where
R: VecZnxToMut,
{
let (tmp, _) = scratch.take_slice(module.vec_znx_automorphism_inplace_tmp_bytes() / size_of::<i64>());
vec_znx_automorphism_inplace::<R, Self>(p, res, res_col, tmp);
}
}
unsafe impl VecZnxMulXpMinusOneImpl<Self> for FFT64Avx {
fn vec_znx_mul_xp_minus_one_impl<R, A>(_module: &Module<Self>, p: i64, res: &mut R, res_col: usize, a: &A, a_col: usize)
where
R: VecZnxToMut,
A: VecZnxToRef,
{
vec_znx_mul_xp_minus_one::<R, A, Self>(p, res, res_col, a, a_col);
}
}
unsafe impl VecZnxMulXpMinusOneInplaceTmpBytesImpl<Self> for FFT64Avx
where
Scratch<Self>: TakeSlice,
Self: VecZnxMulXpMinusOneImpl<Self>,
{
fn vec_znx_mul_xp_minus_one_inplace_tmp_bytes_impl(module: &Module<Self>) -> usize {
vec_znx_mul_xp_minus_one_inplace_tmp_bytes(module.n())
}
}
unsafe impl VecZnxMulXpMinusOneInplaceImpl<Self> for FFT64Avx {
fn vec_znx_mul_xp_minus_one_inplace_impl<R>(
module: &Module<Self>,
p: i64,
res: &mut R,
res_col: usize,
scratch: &mut Scratch<Self>,
) where
R: VecZnxToMut,
{
let (tmp, _) = scratch.take_slice(module.vec_znx_mul_xp_minus_one_inplace_tmp_bytes() / size_of::<i64>());
vec_znx_mul_xp_minus_one_inplace::<R, Self>(p, res, res_col, tmp);
}
}
unsafe impl VecZnxSplitRingTmpBytesImpl<Self> for FFT64Avx {
fn vec_znx_split_ring_tmp_bytes_impl(module: &Module<Self>) -> usize {
vec_znx_split_ring_tmp_bytes(module.n())
}
}
unsafe impl VecZnxSplitRingImpl<Self> for FFT64Avx
where
Module<Self>: VecZnxSplitRingTmpBytes,
Scratch<Self>: TakeSlice,
{
fn vec_znx_split_ring_impl<R, A>(
module: &Module<Self>,
res: &mut [R],
res_col: usize,
a: &A,
a_col: usize,
scratch: &mut Scratch<Self>,
) where
R: VecZnxToMut,
A: VecZnxToRef,
{
let (tmp, _) = scratch.take_slice(module.vec_znx_split_ring_tmp_bytes() / size_of::<i64>());
vec_znx_split_ring::<R, A, Self>(res, res_col, a, a_col, tmp);
}
}
unsafe impl VecZnxMergeRingsTmpBytesImpl<Self> for FFT64Avx {
fn vec_znx_merge_rings_tmp_bytes_impl(module: &Module<Self>) -> usize {
vec_znx_merge_rings_tmp_bytes(module.n())
}
}
unsafe impl VecZnxMergeRingsImpl<Self> for FFT64Avx
where
Module<Self>: VecZnxMergeRingsTmpBytes,
{
fn vec_znx_merge_rings_impl<R, A>(
module: &Module<Self>,
res: &mut R,
res_col: usize,
a: &[A],
a_col: usize,
scratch: &mut Scratch<Self>,
) where
R: VecZnxToMut,
A: VecZnxToRef,
{
let (tmp, _) = scratch.take_slice(module.vec_znx_merge_rings_tmp_bytes() / size_of::<i64>());
vec_znx_merge_rings::<R, A, Self>(res, res_col, a, a_col, tmp);
}
}
unsafe impl VecZnxSwitchRingImpl<Self> for FFT64Avx
where
Self: VecZnxCopyImpl<Self>,
{
fn vec_znx_switch_ring_impl<R, A>(_module: &Module<Self>, res: &mut R, res_col: usize, a: &A, a_col: usize)
where
R: VecZnxToMut,
A: VecZnxToRef,
{
vec_znx_switch_ring::<R, A, Self>(res, res_col, a, a_col);
}
}
unsafe impl VecZnxCopyImpl<Self> for FFT64Avx {
fn vec_znx_copy_impl<R, A>(_module: &Module<Self>, res: &mut R, res_col: usize, a: &A, a_col: usize)
where
R: VecZnxToMut,
A: VecZnxToRef,
{
vec_znx_copy::<R, A, Self>(res, res_col, a, a_col)
}
}
unsafe impl VecZnxFillUniformImpl<Self> for FFT64Avx {
fn vec_znx_fill_uniform_impl<R>(_module: &Module<Self>, base2k: usize, res: &mut R, res_col: usize, source: &mut Source)
where
R: VecZnxToMut,
{
vec_znx_fill_uniform_ref(base2k, res, res_col, source)
}
}
unsafe impl VecZnxFillNormalImpl<Self> for FFT64Avx {
fn vec_znx_fill_normal_impl<R>(
_module: &Module<Self>,
base2k: usize,
res: &mut R,
res_col: usize,
k: usize,
source: &mut Source,
sigma: f64,
bound: f64,
) where
R: VecZnxToMut,
{
vec_znx_fill_normal_ref(base2k, res, res_col, k, sigma, bound, source);
}
}
unsafe impl VecZnxAddNormalImpl<Self> for FFT64Avx {
fn vec_znx_add_normal_impl<R>(
_module: &Module<Self>,
base2k: usize,
res: &mut R,
res_col: usize,
k: usize,
source: &mut Source,
sigma: f64,
bound: f64,
) where
R: VecZnxToMut,
{
vec_znx_add_normal_ref(base2k, res, res_col, k, sigma, bound, source);
}
}


@@ -0,0 +1,333 @@
use crate::FFT64Avx;
use poulpy_hal::{
api::{TakeSlice, VecZnxBigAutomorphismInplaceTmpBytes, VecZnxBigNormalizeTmpBytes},
layouts::{
Backend, Module, Scratch, VecZnx, VecZnxBig, VecZnxBigOwned, VecZnxBigToMut, VecZnxBigToRef, VecZnxToMut, VecZnxToRef,
ZnxInfos, ZnxView, ZnxViewMut,
},
oep::{
TakeSliceImpl, VecZnxBigAddImpl, VecZnxBigAddInplaceImpl, VecZnxBigAddNormalImpl, VecZnxBigAddSmallImpl,
VecZnxBigAddSmallInplaceImpl, VecZnxBigAllocBytesImpl, VecZnxBigAllocImpl, VecZnxBigAutomorphismImpl,
VecZnxBigAutomorphismInplaceImpl, VecZnxBigAutomorphismInplaceTmpBytesImpl, VecZnxBigFromBytesImpl,
VecZnxBigFromSmallImpl, VecZnxBigNegateImpl, VecZnxBigNegateInplaceImpl, VecZnxBigNormalizeImpl,
VecZnxBigNormalizeTmpBytesImpl, VecZnxBigSubImpl, VecZnxBigSubInplaceImpl, VecZnxBigSubNegateInplaceImpl,
VecZnxBigSubSmallAImpl, VecZnxBigSubSmallBImpl, VecZnxBigSubSmallInplaceImpl, VecZnxBigSubSmallNegateInplaceImpl,
},
reference::{
fft64::vec_znx_big::{
vec_znx_big_add, vec_znx_big_add_inplace, vec_znx_big_add_normal_ref, vec_znx_big_add_small,
vec_znx_big_add_small_inplace, vec_znx_big_automorphism, vec_znx_big_automorphism_inplace,
vec_znx_big_automorphism_inplace_tmp_bytes, vec_znx_big_negate, vec_znx_big_negate_inplace, vec_znx_big_normalize,
vec_znx_big_normalize_tmp_bytes, vec_znx_big_sub, vec_znx_big_sub_inplace, vec_znx_big_sub_negate_inplace,
vec_znx_big_sub_small_a, vec_znx_big_sub_small_a_inplace, vec_znx_big_sub_small_b, vec_znx_big_sub_small_b_inplace,
},
znx::{znx_copy_ref, znx_zero_ref},
},
source::Source,
};
unsafe impl VecZnxBigAllocBytesImpl<Self> for FFT64Avx {
fn vec_znx_big_bytes_of_impl(n: usize, cols: usize, size: usize) -> usize {
Self::layout_big_word_count() * n * cols * size * size_of::<f64>()
}
}
unsafe impl VecZnxBigAllocImpl<Self> for FFT64Avx {
fn vec_znx_big_alloc_impl(n: usize, cols: usize, size: usize) -> VecZnxBigOwned<Self> {
VecZnxBig::alloc(n, cols, size)
}
}
unsafe impl VecZnxBigFromBytesImpl<Self> for FFT64Avx {
fn vec_znx_big_from_bytes_impl(n: usize, cols: usize, size: usize, bytes: Vec<u8>) -> VecZnxBigOwned<Self> {
VecZnxBig::from_bytes(n, cols, size, bytes)
}
}
unsafe impl VecZnxBigFromSmallImpl<Self> for FFT64Avx {
fn vec_znx_big_from_small_impl<R, A>(res: &mut R, res_col: usize, a: &A, a_col: usize)
where
R: VecZnxBigToMut<Self>,
A: VecZnxToRef,
{
let mut res: VecZnxBig<&mut [u8], FFT64Avx> = res.to_mut();
let a: VecZnx<&[u8]> = a.to_ref();
#[cfg(debug_assertions)]
{
assert_eq!(res.n(), a.n());
}
let res_size: usize = res.size();
let a_size: usize = a.size();
let min_size: usize = res_size.min(a_size);
for j in 0..min_size {
znx_copy_ref(res.at_mut(res_col, j), a.at(a_col, j));
}
for j in min_size..res_size {
znx_zero_ref(res.at_mut(res_col, j));
}
}
}
unsafe impl VecZnxBigAddNormalImpl<Self> for FFT64Avx {
fn add_normal_impl<R: VecZnxBigToMut<Self>>(
_module: &Module<Self>,
base2k: usize,
res: &mut R,
res_col: usize,
k: usize,
source: &mut Source,
sigma: f64,
bound: f64,
) {
vec_znx_big_add_normal_ref(base2k, res, res_col, k, sigma, bound, source);
}
}
unsafe impl VecZnxBigAddImpl<Self> for FFT64Avx {
    /// Adds `a` to `b` and stores the result on `res`.
fn vec_znx_big_add_impl<R, A, B>(
_module: &Module<Self>,
res: &mut R,
res_col: usize,
a: &A,
a_col: usize,
b: &B,
b_col: usize,
) where
R: VecZnxBigToMut<Self>,
A: VecZnxBigToRef<Self>,
B: VecZnxBigToRef<Self>,
{
vec_znx_big_add(res, res_col, a, a_col, b, b_col);
}
}
unsafe impl VecZnxBigAddInplaceImpl<Self> for FFT64Avx {
    /// Adds `a` to `res` and stores the result on `res`.
fn vec_znx_big_add_inplace_impl<R, A>(_module: &Module<Self>, res: &mut R, res_col: usize, a: &A, a_col: usize)
where
R: VecZnxBigToMut<Self>,
A: VecZnxBigToRef<Self>,
{
vec_znx_big_add_inplace(res, res_col, a, a_col);
}
}
unsafe impl VecZnxBigAddSmallImpl<Self> for FFT64Avx {
    /// Adds `a` to `b` and stores the result on `res`.
fn vec_znx_big_add_small_impl<R, A, B>(
_module: &Module<Self>,
res: &mut R,
res_col: usize,
a: &A,
a_col: usize,
b: &B,
b_col: usize,
) where
R: VecZnxBigToMut<Self>,
A: VecZnxBigToRef<Self>,
B: VecZnxToRef,
{
vec_znx_big_add_small(res, res_col, a, a_col, b, b_col);
}
}
unsafe impl VecZnxBigAddSmallInplaceImpl<Self> for FFT64Avx {
    /// Adds `a` to `res` and stores the result on `res`.
fn vec_znx_big_add_small_inplace_impl<R, A>(_module: &Module<Self>, res: &mut R, res_col: usize, a: &A, a_col: usize)
where
R: VecZnxBigToMut<Self>,
A: VecZnxToRef,
{
vec_znx_big_add_small_inplace(res, res_col, a, a_col);
}
}
unsafe impl VecZnxBigSubImpl<Self> for FFT64Avx {
    /// Subtracts `b` from `a` and stores the result on `res`.
fn vec_znx_big_sub_impl<R, A, B>(
_module: &Module<Self>,
res: &mut R,
res_col: usize,
a: &A,
a_col: usize,
b: &B,
b_col: usize,
) where
R: VecZnxBigToMut<Self>,
A: VecZnxBigToRef<Self>,
B: VecZnxBigToRef<Self>,
{
vec_znx_big_sub(res, res_col, a, a_col, b, b_col);
}
}
unsafe impl VecZnxBigSubInplaceImpl<Self> for FFT64Avx {
    /// Subtracts `a` from `res` and stores the result on `res`.
fn vec_znx_big_sub_inplace_impl<R, A>(_module: &Module<Self>, res: &mut R, res_col: usize, a: &A, a_col: usize)
where
R: VecZnxBigToMut<Self>,
A: VecZnxBigToRef<Self>,
{
vec_znx_big_sub_inplace(res, res_col, a, a_col);
}
}
unsafe impl VecZnxBigSubNegateInplaceImpl<Self> for FFT64Avx {
    /// Subtracts `res` from `a` and stores the result on `res`.
fn vec_znx_big_sub_negate_inplace_impl<R, A>(_module: &Module<Self>, res: &mut R, res_col: usize, a: &A, a_col: usize)
where
R: VecZnxBigToMut<Self>,
A: VecZnxBigToRef<Self>,
{
vec_znx_big_sub_negate_inplace(res, res_col, a, a_col);
}
}
unsafe impl VecZnxBigSubSmallAImpl<Self> for FFT64Avx {
    /// Subtracts `b` from `a` and stores the result on `res`.
fn vec_znx_big_sub_small_a_impl<R, A, B>(
_module: &Module<Self>,
res: &mut R,
res_col: usize,
a: &A,
a_col: usize,
b: &B,
b_col: usize,
) where
R: VecZnxBigToMut<Self>,
A: VecZnxToRef,
B: VecZnxBigToRef<Self>,
{
vec_znx_big_sub_small_a(res, res_col, a, a_col, b, b_col);
}
}
unsafe impl VecZnxBigSubSmallInplaceImpl<Self> for FFT64Avx {
/// Subtracts `a` from `res` and stores the result on `res`.
fn vec_znx_big_sub_small_inplace_impl<R, A>(_module: &Module<Self>, res: &mut R, res_col: usize, a: &A, a_col: usize)
where
R: VecZnxBigToMut<Self>,
A: VecZnxToRef,
{
vec_znx_big_sub_small_a_inplace(res, res_col, a, a_col);
}
}
unsafe impl VecZnxBigSubSmallBImpl<Self> for FFT64Avx {
    /// Subtracts `b` from `a` and stores the result on `res`.
fn vec_znx_big_sub_small_b_impl<R, A, B>(
_module: &Module<Self>,
res: &mut R,
res_col: usize,
a: &A,
a_col: usize,
b: &B,
b_col: usize,
) where
R: VecZnxBigToMut<Self>,
A: VecZnxBigToRef<Self>,
B: VecZnxToRef,
{
vec_znx_big_sub_small_b(res, res_col, a, a_col, b, b_col);
}
}
unsafe impl VecZnxBigSubSmallNegateInplaceImpl<Self> for FFT64Avx {
/// Subtracts `res` from `a` and stores the result on `res`.
fn vec_znx_big_sub_small_negate_inplace_impl<R, A>(_module: &Module<Self>, res: &mut R, res_col: usize, a: &A, a_col: usize)
where
R: VecZnxBigToMut<Self>,
A: VecZnxToRef,
{
vec_znx_big_sub_small_b_inplace(res, res_col, a, a_col);
}
}
unsafe impl VecZnxBigNegateImpl<Self> for FFT64Avx {
fn vec_znx_big_negate_impl<R, A>(_module: &Module<Self>, res: &mut R, res_col: usize, a: &A, a_col: usize)
where
R: VecZnxBigToMut<Self>,
A: VecZnxBigToRef<Self>,
{
vec_znx_big_negate(res, res_col, a, a_col);
}
}
unsafe impl VecZnxBigNegateInplaceImpl<Self> for FFT64Avx {
fn vec_znx_big_negate_inplace_impl<R>(_module: &Module<Self>, res: &mut R, res_col: usize)
where
R: VecZnxBigToMut<Self>,
{
vec_znx_big_negate_inplace(res, res_col);
}
}
unsafe impl VecZnxBigNormalizeTmpBytesImpl<Self> for FFT64Avx {
fn vec_znx_big_normalize_tmp_bytes_impl(module: &Module<Self>) -> usize {
vec_znx_big_normalize_tmp_bytes(module.n())
}
}
unsafe impl VecZnxBigNormalizeImpl<Self> for FFT64Avx
where
Self: TakeSliceImpl<Self>,
{
fn vec_znx_big_normalize_impl<R, A>(
module: &Module<Self>,
res_basek: usize,
res: &mut R,
res_col: usize,
a_basek: usize,
a: &A,
a_col: usize,
scratch: &mut Scratch<Self>,
) where
R: VecZnxToMut,
A: VecZnxBigToRef<Self>,
{
let (carry, _) = scratch.take_slice(module.vec_znx_big_normalize_tmp_bytes() / size_of::<i64>());
vec_znx_big_normalize(res_basek, res, res_col, a_basek, a, a_col, carry);
}
}
unsafe impl VecZnxBigAutomorphismImpl<Self> for FFT64Avx {
    /// Applies the automorphism X^i -> X^(i*p) on `a` and stores the result on `res`.
fn vec_znx_big_automorphism_impl<R, A>(_module: &Module<Self>, p: i64, res: &mut R, res_col: usize, a: &A, a_col: usize)
where
R: VecZnxBigToMut<Self>,
A: VecZnxBigToRef<Self>,
{
vec_znx_big_automorphism(p, res, res_col, a, a_col);
}
}
unsafe impl VecZnxBigAutomorphismInplaceTmpBytesImpl<Self> for FFT64Avx {
fn vec_znx_big_automorphism_inplace_tmp_bytes_impl(module: &Module<Self>) -> usize {
vec_znx_big_automorphism_inplace_tmp_bytes(module.n())
}
}
unsafe impl VecZnxBigAutomorphismInplaceImpl<Self> for FFT64Avx
where
Module<Self>: VecZnxBigAutomorphismInplaceTmpBytes,
{
    /// Applies the automorphism X^i -> X^(i*p) on `res` and stores the result on `res`.
fn vec_znx_big_automorphism_inplace_impl<R>(
module: &Module<Self>,
p: i64,
res: &mut R,
res_col: usize,
scratch: &mut Scratch<Self>,
) where
R: VecZnxBigToMut<Self>,
{
let (tmp, _) = scratch.take_slice(module.vec_znx_big_automorphism_inplace_tmp_bytes() / size_of::<i64>());
vec_znx_big_automorphism_inplace(p, res, res_col, tmp);
}
}


@@ -0,0 +1,203 @@
use poulpy_hal::{
layouts::{
Backend, Data, Module, Scratch, VecZnxBig, VecZnxBigToMut, VecZnxDft, VecZnxDftOwned, VecZnxDftToMut, VecZnxDftToRef,
VecZnxToRef,
},
oep::{
VecZnxDftAddImpl, VecZnxDftAddInplaceImpl, VecZnxDftAddScaledInplaceImpl, VecZnxDftAllocBytesImpl, VecZnxDftAllocImpl,
VecZnxDftApplyImpl, VecZnxDftCopyImpl, VecZnxDftFromBytesImpl, VecZnxDftSubImpl, VecZnxDftSubInplaceImpl,
VecZnxDftSubNegateInplaceImpl, VecZnxDftZeroImpl, VecZnxIdftApplyConsumeImpl, VecZnxIdftApplyImpl,
VecZnxIdftApplyTmpAImpl, VecZnxIdftApplyTmpBytesImpl,
},
reference::fft64::vec_znx_dft::{
vec_znx_dft_add, vec_znx_dft_add_inplace, vec_znx_dft_add_scaled_inplace, vec_znx_dft_apply, vec_znx_dft_copy,
vec_znx_dft_sub, vec_znx_dft_sub_inplace, vec_znx_dft_sub_negate_inplace, vec_znx_dft_zero, vec_znx_idft_apply,
vec_znx_idft_apply_consume, vec_znx_idft_apply_tmpa,
},
};
use crate::{FFT64Avx, module::FFT64ModuleHandle};
unsafe impl VecZnxDftFromBytesImpl<Self> for FFT64Avx {
fn vec_znx_dft_from_bytes_impl(n: usize, cols: usize, size: usize, bytes: Vec<u8>) -> VecZnxDftOwned<Self> {
VecZnxDft::<Vec<u8>, Self>::from_bytes(n, cols, size, bytes)
}
}
unsafe impl VecZnxDftAllocBytesImpl<Self> for FFT64Avx {
fn vec_znx_dft_bytes_of_impl(n: usize, cols: usize, size: usize) -> usize {
Self::layout_prep_word_count() * n * cols * size * size_of::<<FFT64Avx as Backend>::ScalarPrep>()
}
}
unsafe impl VecZnxDftAllocImpl<Self> for FFT64Avx {
fn vec_znx_dft_alloc_impl(n: usize, cols: usize, size: usize) -> VecZnxDftOwned<Self> {
VecZnxDftOwned::alloc(n, cols, size)
}
}
unsafe impl VecZnxIdftApplyTmpBytesImpl<Self> for FFT64Avx {
fn vec_znx_idft_apply_tmp_bytes_impl(_module: &Module<Self>) -> usize {
0
}
}
unsafe impl VecZnxIdftApplyImpl<Self> for FFT64Avx {
fn vec_znx_idft_apply_impl<R, A>(
module: &Module<Self>,
res: &mut R,
res_col: usize,
a: &A,
a_col: usize,
_scratch: &mut Scratch<Self>,
) where
R: VecZnxBigToMut<Self>,
A: VecZnxDftToRef<Self>,
{
vec_znx_idft_apply(module.get_ifft_table(), res, res_col, a, a_col);
}
}
unsafe impl VecZnxIdftApplyTmpAImpl<Self> for FFT64Avx {
fn vec_znx_idft_apply_tmpa_impl<R, A>(module: &Module<Self>, res: &mut R, res_col: usize, a: &mut A, a_col: usize)
where
R: VecZnxBigToMut<Self>,
A: VecZnxDftToMut<Self>,
{
vec_znx_idft_apply_tmpa(module.get_ifft_table(), res, res_col, a, a_col);
}
}
unsafe impl VecZnxIdftApplyConsumeImpl<Self> for FFT64Avx {
fn vec_znx_idft_apply_consume_impl<D: Data>(module: &Module<Self>, res: VecZnxDft<D, FFT64Avx>) -> VecZnxBig<D, FFT64Avx>
where
VecZnxDft<D, FFT64Avx>: VecZnxDftToMut<Self>,
{
vec_znx_idft_apply_consume(module.get_ifft_table(), res)
}
}
unsafe impl VecZnxDftApplyImpl<Self> for FFT64Avx {
fn vec_znx_dft_apply_impl<R, A>(
module: &Module<Self>,
step: usize,
offset: usize,
res: &mut R,
res_col: usize,
a: &A,
a_col: usize,
) where
R: VecZnxDftToMut<Self>,
A: VecZnxToRef,
{
vec_znx_dft_apply(module.get_fft_table(), step, offset, res, res_col, a, a_col);
}
}
unsafe impl VecZnxDftAddImpl<Self> for FFT64Avx {
fn vec_znx_dft_add_impl<R, A, B>(
_module: &Module<Self>,
res: &mut R,
res_col: usize,
a: &A,
a_col: usize,
b: &B,
b_col: usize,
) where
R: VecZnxDftToMut<Self>,
A: VecZnxDftToRef<Self>,
B: VecZnxDftToRef<Self>,
{
vec_znx_dft_add(res, res_col, a, a_col, b, b_col);
}
}
unsafe impl VecZnxDftAddInplaceImpl<Self> for FFT64Avx {
fn vec_znx_dft_add_inplace_impl<R, A>(_module: &Module<Self>, res: &mut R, res_col: usize, a: &A, a_col: usize)
where
R: VecZnxDftToMut<Self>,
A: VecZnxDftToRef<Self>,
{
vec_znx_dft_add_inplace(res, res_col, a, a_col);
}
}
unsafe impl VecZnxDftAddScaledInplaceImpl<Self> for FFT64Avx {
fn vec_znx_dft_add_scaled_inplace_impl<R, A>(
_module: &Module<Self>,
res: &mut R,
res_col: usize,
a: &A,
a_col: usize,
a_scale: i64,
) where
R: VecZnxDftToMut<Self>,
A: VecZnxDftToRef<Self>,
{
vec_znx_dft_add_scaled_inplace(res, res_col, a, a_col, a_scale);
}
}
unsafe impl VecZnxDftSubImpl<Self> for FFT64Avx {
fn vec_znx_dft_sub_impl<R, A, B>(
_module: &Module<Self>,
res: &mut R,
res_col: usize,
a: &A,
a_col: usize,
b: &B,
b_col: usize,
) where
R: VecZnxDftToMut<Self>,
A: VecZnxDftToRef<Self>,
B: VecZnxDftToRef<Self>,
{
vec_znx_dft_sub(res, res_col, a, a_col, b, b_col);
}
}
unsafe impl VecZnxDftSubInplaceImpl<Self> for FFT64Avx {
fn vec_znx_dft_sub_inplace_impl<R, A>(_module: &Module<Self>, res: &mut R, res_col: usize, a: &A, a_col: usize)
where
R: VecZnxDftToMut<Self>,
A: VecZnxDftToRef<Self>,
{
vec_znx_dft_sub_inplace(res, res_col, a, a_col);
}
}
unsafe impl VecZnxDftSubNegateInplaceImpl<Self> for FFT64Avx {
fn vec_znx_dft_sub_negate_inplace_impl<R, A>(_module: &Module<Self>, res: &mut R, res_col: usize, a: &A, a_col: usize)
where
R: VecZnxDftToMut<Self>,
A: VecZnxDftToRef<Self>,
{
vec_znx_dft_sub_negate_inplace(res, res_col, a, a_col);
}
}
unsafe impl VecZnxDftCopyImpl<Self> for FFT64Avx {
fn vec_znx_dft_copy_impl<R, A>(
_module: &Module<Self>,
step: usize,
offset: usize,
res: &mut R,
res_col: usize,
a: &A,
a_col: usize,
) where
R: VecZnxDftToMut<Self>,
A: VecZnxDftToRef<Self>,
{
vec_znx_dft_copy(step, offset, res, res_col, a, a_col);
}
}
unsafe impl VecZnxDftZeroImpl<Self> for FFT64Avx {
fn vec_znx_dft_zero_impl<R>(_module: &Module<Self>, res: &mut R, res_col: usize)
where
R: VecZnxDftToMut<Self>,
{
vec_znx_dft_zero(res, res_col);
}
}

153
poulpy-cpu-avx/src/vmp.rs Normal file

@@ -0,0 +1,153 @@
use poulpy_hal::{
api::{TakeSlice, VmpPrepareTmpBytes},
layouts::{
Backend, MatZnx, MatZnxToRef, Module, Scratch, VecZnxDft, VecZnxDftToMut, VecZnxDftToRef, VmpPMat, VmpPMatOwned,
VmpPMatToMut, VmpPMatToRef, ZnxInfos,
},
oep::{
VmpApplyDftToDftAddImpl, VmpApplyDftToDftAddTmpBytesImpl, VmpApplyDftToDftImpl, VmpApplyDftToDftTmpBytesImpl,
VmpPMatAllocBytesImpl, VmpPMatAllocImpl, VmpPrepareImpl, VmpPrepareTmpBytesImpl, VmpZeroImpl,
},
reference::fft64::vmp::{
vmp_apply_dft_to_dft, vmp_apply_dft_to_dft_add, vmp_apply_dft_to_dft_tmp_bytes, vmp_prepare, vmp_prepare_tmp_bytes,
vmp_zero,
},
};
use crate::{FFT64Avx, module::FFT64ModuleHandle};
unsafe impl VmpPMatAllocBytesImpl<Self> for FFT64Avx {
fn vmp_pmat_bytes_of_impl(n: usize, rows: usize, cols_in: usize, cols_out: usize, size: usize) -> usize {
Self::layout_prep_word_count() * n * rows * cols_in * cols_out * size * size_of::<f64>()
}
}
unsafe impl VmpPMatAllocImpl<Self> for FFT64Avx {
fn vmp_pmat_alloc_impl(n: usize, rows: usize, cols_in: usize, cols_out: usize, size: usize) -> VmpPMatOwned<Self> {
VmpPMatOwned::alloc(n, rows, cols_in, cols_out, size)
}
}
unsafe impl VmpApplyDftToDftImpl<Self> for FFT64Avx
where
Scratch<Self>: TakeSlice,
FFT64Avx: VmpApplyDftToDftTmpBytesImpl<Self>,
{
fn vmp_apply_dft_to_dft_impl<R, A, C>(module: &Module<Self>, res: &mut R, a: &A, pmat: &C, scratch: &mut Scratch<Self>)
where
R: VecZnxDftToMut<Self>,
A: VecZnxDftToRef<Self>,
C: VmpPMatToRef<Self>,
{
let mut res: VecZnxDft<&mut [u8], Self> = res.to_mut();
let a: VecZnxDft<&[u8], Self> = a.to_ref();
let pmat: VmpPMat<&[u8], Self> = pmat.to_ref();
let (tmp, _) = scratch.take_slice(
Self::vmp_apply_dft_to_dft_tmp_bytes_impl(
module,
res.size(),
a.size(),
pmat.rows(),
pmat.cols_in(),
pmat.cols_out(),
pmat.size(),
) / size_of::<f64>(),
);
vmp_apply_dft_to_dft(&mut res, &a, &pmat, tmp);
}
}
unsafe impl VmpApplyDftToDftAddImpl<Self> for FFT64Avx
where
Scratch<Self>: TakeSlice,
FFT64Avx: VmpApplyDftToDftTmpBytesImpl<Self>,
{
fn vmp_apply_dft_to_dft_add_impl<R, A, C>(
module: &Module<Self>,
res: &mut R,
a: &A,
pmat: &C,
limb_offset: usize,
scratch: &mut Scratch<Self>,
) where
R: VecZnxDftToMut<Self>,
A: VecZnxDftToRef<Self>,
C: VmpPMatToRef<Self>,
{
let mut res: VecZnxDft<&mut [u8], Self> = res.to_mut();
let a: VecZnxDft<&[u8], Self> = a.to_ref();
let pmat: VmpPMat<&[u8], Self> = pmat.to_ref();
let (tmp, _) = scratch.take_slice(
Self::vmp_apply_dft_to_dft_tmp_bytes_impl(
module,
res.size(),
a.size(),
pmat.rows(),
pmat.cols_in(),
pmat.cols_out(),
pmat.size(),
) / size_of::<f64>(),
);
vmp_apply_dft_to_dft_add(&mut res, &a, &pmat, limb_offset * pmat.cols_out(), tmp);
}
}
unsafe impl VmpPrepareTmpBytesImpl<Self> for FFT64Avx {
fn vmp_prepare_tmp_bytes_impl(module: &Module<Self>, _rows: usize, _cols_in: usize, _cols_out: usize, _size: usize) -> usize {
vmp_prepare_tmp_bytes(module.n())
}
}
unsafe impl VmpPrepareImpl<Self> for FFT64Avx {
fn vmp_prepare_impl<R, A>(module: &Module<Self>, res: &mut R, a: &A, scratch: &mut Scratch<Self>)
where
R: VmpPMatToMut<Self>,
A: MatZnxToRef,
{
let mut res: VmpPMat<&mut [u8], Self> = res.to_mut();
let a: MatZnx<&[u8]> = a.to_ref();
let (tmp, _) =
scratch.take_slice(module.vmp_prepare_tmp_bytes(a.rows(), a.cols_in(), a.cols_out(), a.size()) / size_of::<f64>());
vmp_prepare(module.get_fft_table(), &mut res, &a, tmp);
}
}
unsafe impl VmpApplyDftToDftTmpBytesImpl<Self> for FFT64Avx {
fn vmp_apply_dft_to_dft_tmp_bytes_impl(
_module: &Module<Self>,
_res_size: usize,
a_size: usize,
b_rows: usize,
b_cols_in: usize,
_b_cols_out: usize,
_b_size: usize,
) -> usize {
vmp_apply_dft_to_dft_tmp_bytes(a_size, b_rows, b_cols_in)
}
}
unsafe impl VmpApplyDftToDftAddTmpBytesImpl<Self> for FFT64Avx {
fn vmp_apply_dft_to_dft_add_tmp_bytes_impl(
_module: &Module<Self>,
_res_size: usize,
a_size: usize,
b_rows: usize,
b_cols_in: usize,
_b_cols_out: usize,
_b_size: usize,
) -> usize {
vmp_apply_dft_to_dft_tmp_bytes(a_size, b_rows, b_cols_in)
}
}
unsafe impl VmpZeroImpl<Self> for FFT64Avx {
fn vmp_zero_impl<R>(_module: &Module<Self>, res: &mut R)
where
R: VmpPMatToMut<Self>,
{
vmp_zero(res);
}
}


@@ -0,0 +1,74 @@
/// # Safety
/// Caller must ensure the CPU supports AVX2 (e.g., via `is_x86_feature_detected!("avx2")`);
/// all inputs must have the same length and must not alias.
#[target_feature(enable = "avx2")]
pub fn znx_add_avx(res: &mut [i64], a: &[i64], b: &[i64]) {
#[cfg(debug_assertions)]
{
assert_eq!(res.len(), a.len());
assert_eq!(res.len(), b.len());
}
use core::arch::x86_64::{__m256i, _mm256_add_epi64, _mm256_loadu_si256, _mm256_storeu_si256};
let n: usize = res.len();
let span: usize = n >> 2;
let mut rr: *mut __m256i = res.as_mut_ptr() as *mut __m256i;
let mut aa: *const __m256i = a.as_ptr() as *const __m256i;
let mut bb: *const __m256i = b.as_ptr() as *const __m256i;
unsafe {
for _ in 0..span {
let sum: __m256i = _mm256_add_epi64(_mm256_loadu_si256(aa), _mm256_loadu_si256(bb));
_mm256_storeu_si256(rr, sum);
rr = rr.add(1);
aa = aa.add(1);
bb = bb.add(1);
}
}
// tail
if !res.len().is_multiple_of(4) {
use poulpy_hal::reference::znx::znx_add_ref;
znx_add_ref(&mut res[span << 2..], &a[span << 2..], &b[span << 2..]);
}
}
/// # Safety
/// Caller must ensure the CPU supports AVX2 (e.g., via `is_x86_feature_detected!("avx2")`);
/// all inputs must have the same length and must not alias.
#[target_feature(enable = "avx2")]
pub fn znx_add_inplace_avx(res: &mut [i64], a: &[i64]) {
#[cfg(debug_assertions)]
{
assert_eq!(res.len(), a.len());
}
use core::arch::x86_64::{__m256i, _mm256_add_epi64, _mm256_loadu_si256, _mm256_storeu_si256};
let n: usize = res.len();
let span: usize = n >> 2;
let mut rr: *mut __m256i = res.as_mut_ptr() as *mut __m256i;
let mut aa: *const __m256i = a.as_ptr() as *const __m256i;
unsafe {
for _ in 0..span {
let sum: __m256i = _mm256_add_epi64(_mm256_loadu_si256(rr), _mm256_loadu_si256(aa));
_mm256_storeu_si256(rr, sum);
rr = rr.add(1);
aa = aa.add(1);
}
}
// tail
if !res.len().is_multiple_of(4) {
use poulpy_hal::reference::znx::znx_add_inplace_ref;
znx_add_inplace_ref(&mut res[span << 2..], &a[span << 2..]);
}
}
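// Sanity-check sketch (illustrative): the vectorized paths plus their scalar tails must
// agree with the reference implementations, including for lengths that are not a
// multiple of 4.
#[test]
fn test_znx_add_avx_matches_ref() {
    if !std::is_x86_feature_detected!("avx2") {
        eprintln!("skipping: CPU lacks avx2");
        return;
    }
    use poulpy_hal::reference::znx::{znx_add_inplace_ref, znx_add_ref};
    let n: usize = 13; // deliberately not a multiple of 4 to exercise the tail path
    let a: Vec<i64> = (0..n as i64).collect();
    let b: Vec<i64> = (0..n as i64).map(|x| 1000 - 3 * x).collect();
    let mut res_avx: Vec<i64> = vec![0i64; n];
    let mut res_ref: Vec<i64> = vec![0i64; n];
    unsafe { znx_add_avx(&mut res_avx, &a, &b) };
    znx_add_ref(&mut res_ref, &a, &b);
    assert_eq!(res_avx, res_ref);
    unsafe { znx_add_inplace_avx(&mut res_avx, &a) };
    znx_add_inplace_ref(&mut res_ref, &a);
    assert_eq!(res_avx, res_ref);
}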


@@ -0,0 +1,132 @@
use core::arch::x86_64::*;
#[inline]
fn inv_mod_pow2(p: usize, bits: u32) -> usize {
// Compute p^{-1} mod 2^bits (p must be odd) through Hensel lifting.
debug_assert!(p % 2 == 1);
let mut x: usize = 1usize; // inverse mod 2
let mut i: u32 = 1;
while i < bits {
// x <- x * (2 - p*x) mod 2^(2^i) (wrapping arithmetic)
x = x.wrapping_mul(2usize.wrapping_sub(p.wrapping_mul(x)));
i <<= 1;
}
x & ((1usize << bits) - 1)
}
/// # Safety
/// Caller must ensure the CPU supports AVX2 (e.g., via `is_x86_feature_detected!("avx2")`);
/// all inputs must have the same length and must not alias.
#[target_feature(enable = "avx2", enable = "fma")]
pub fn znx_automorphism_avx(p: i64, res: &mut [i64], a: &[i64]) {
debug_assert_eq!(res.len(), a.len());
let n: usize = res.len();
if n == 0 {
return;
}
debug_assert!(n.is_power_of_two(), "n must be power of two");
debug_assert!(p & 1 == 1, "p must be odd (invertible mod 2n)");
if n < 4 {
use poulpy_hal::reference::znx::znx_automorphism_ref;
znx_automorphism_ref(p, res, a);
return;
}
unsafe {
let two_n: usize = n << 1;
let span: usize = n >> 2;
let bits: u32 = (two_n as u64).trailing_zeros();
let mask_2n: usize = two_n - 1;
let mask_1n: usize = n - 1;
// p mod 2n (positive)
let p_2n: usize = (((p & mask_2n as i64) + two_n as i64) as usize) & mask_2n;
// p^-1 mod 2n
let inv: usize = inv_mod_pow2(p_2n, bits);
// Broadcast constants
let n_minus1_vec: __m256i = _mm256_set1_epi64x((n as i64) - 1);
let mask_2n_vec: __m256i = _mm256_set1_epi64x(mask_2n as i64);
let mask_1n_vec: __m256i = _mm256_set1_epi64x(mask_1n as i64);
// Lane offsets [0, inv, 2*inv, 3*inv] (mod 2n)
let lane_offsets: __m256i = _mm256_set_epi64x(
((inv * 3) & mask_2n) as i64,
((inv * 2) & mask_2n) as i64,
inv as i64,
0i64,
);
// t_base = (j * inv) mod 2n.
let mut t_base: usize = 0;
let step: usize = (inv << 2) & mask_2n;
let mut rr: *mut __m256i = res.as_mut_ptr() as *mut __m256i;
let aa: *const i64 = a.as_ptr();
for _ in 0..span {
// t_vec = (t_base + [0, inv, 2*inv, 3*inv]) & (2n-1)
let t_base_vec: __m256i = _mm256_set1_epi64x(t_base as i64);
let t_vec: __m256i = _mm256_and_si256(_mm256_add_epi64(t_base_vec, lane_offsets), mask_2n_vec);
// idx = t_vec & (n-1)
let idx_vec: __m256i = _mm256_and_si256(t_vec, mask_1n_vec);
// sign = t >= n ? -1 : 0 (mask of all-ones where negate)
let sign_mask: __m256i = _mm256_cmpgt_epi64(t_vec, n_minus1_vec);
// gather a[idx] (scale = 8 bytes per i64)
let vals: __m256i = _mm256_i64gather_epi64(aa, idx_vec, 8);
// Conditional negate: (vals ^ sign_mask) - sign_mask
let vals_x: __m256i = _mm256_xor_si256(vals, sign_mask);
let out: __m256i = _mm256_sub_epi64(vals_x, sign_mask);
// store to res[j..j+4]
_mm256_storeu_si256(rr, out);
// advance
rr = rr.add(1);
t_base = (t_base + step) & mask_2n;
}
}
}
#[cfg(test)]
mod tests {
use poulpy_hal::reference::znx::znx_automorphism_ref;
use super::*;
#[allow(dead_code)]
#[target_feature(enable = "avx2", enable = "fma")]
fn test_znx_automorphism_internal() {
let a: [i64; 16] = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16];
let p: i64 = -5;
let mut r0: Vec<i64> = vec![0i64; a.len()];
let mut r1: Vec<i64> = vec![0i64; a.len()];
znx_automorphism_ref(p, &mut r0, &a);
znx_automorphism_avx(p, &mut r1, &a);
assert_eq!(r0, r1);
}
#[test]
fn test_znx_automorphism_avx() {
if !std::is_x86_feature_detected!("avx2") {
eprintln!("skipping: CPU lacks avx2");
return;
};
unsafe {
test_znx_automorphism_internal();
}
}
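// Additional check (illustrative sketch, same pattern as above): compare the AVX
// kernel against the reference on a larger, deterministically generated input.
#[test]
fn test_znx_automorphism_avx_large() {
if !std::is_x86_feature_detected!("avx2") {
eprintln!("skipping: CPU lacks avx2");
return;
}
let n: usize = 64;
// small LCG so the test stays dependency-free and reproducible
let mut state: u64 = 0x9E37_79B9_7F4A_7C15;
let a: Vec<i64> = (0..n)
.map(|_| {
state = state.wrapping_mul(6364136223846793005).wrapping_add(1);
(state >> 16) as i64
})
.collect();
for p in [3i64, -5, 7, 2 * (n as i64) - 1] {
let mut r0: Vec<i64> = vec![0i64; n];
let mut r1: Vec<i64> = vec![0i64; n];
znx_automorphism_ref(p, &mut r0, &a);
unsafe { znx_automorphism_avx(p, &mut r1, &a) };
assert_eq!(r0, r1, "mismatch for p = {p}");
}
}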
}

View File

@@ -0,0 +1,15 @@
mod add;
mod automorphism;
mod mul;
mod neg;
mod normalization;
mod sub;
mod switch_ring;
pub(crate) use add::*;
pub(crate) use automorphism::*;
pub(crate) use mul::*;
pub(crate) use neg::*;
pub(crate) use normalization::*;
pub(crate) use sub::*;
pub(crate) use switch_ring::*;

View File

@@ -0,0 +1,318 @@
/// Multiply/divide by a power of two with rounding matching [poulpy_hal::reference::znx::znx_mul_power_of_two_ref].
///
/// # Safety
/// Caller must ensure the CPU supports AVX2 (e.g., via `is_x86_feature_detected!("avx2")`);
/// all inputs must have the same length and must not alias.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
pub unsafe fn znx_mul_power_of_two_avx(k: i64, res: &mut [i64], a: &[i64]) {
#[cfg(debug_assertions)]
{
assert_eq!(res.len(), a.len());
}
use core::arch::x86_64::{
__m128i, __m256i, _mm_cvtsi32_si128, _mm256_add_epi64, _mm256_and_si256, _mm256_cmpgt_epi64, _mm256_loadu_si256,
_mm256_or_si256, _mm256_set1_epi64x, _mm256_setzero_si256, _mm256_sll_epi64, _mm256_srl_epi64, _mm256_srli_epi64,
_mm256_storeu_si256, _mm256_sub_epi64,
};
let n: usize = res.len();
if n == 0 {
return;
}
if k == 0 {
use poulpy_hal::reference::znx::znx_copy_ref;
znx_copy_ref(res, a);
return;
}
let span: usize = n >> 2; // number of 256-bit chunks
unsafe {
let mut rr: *mut __m256i = res.as_mut_ptr() as *mut __m256i;
let mut aa: *const __m256i = a.as_ptr() as *const __m256i;
if k > 0 {
// Left shift by k (variable count).
#[cfg(debug_assertions)]
{
debug_assert!(k <= 63);
}
let cnt128: __m128i = _mm_cvtsi32_si128(k as i32);
for _ in 0..span {
let x: __m256i = _mm256_loadu_si256(aa);
let y: __m256i = _mm256_sll_epi64(x, cnt128);
_mm256_storeu_si256(rr, y);
rr = rr.add(1);
aa = aa.add(1);
}
// tail
if !n.is_multiple_of(4) {
use poulpy_hal::reference::znx::znx_mul_power_of_two_ref;
znx_mul_power_of_two_ref(k, &mut res[span << 2..], &a[span << 2..]);
}
return;
}
// k < 0 => arithmetic right shift with rounding:
// for each x:
// sign_bit = (x >> 63) & 1
// bias = (1<<(kp-1)) - sign_bit
// t = x + bias
// y = t >> kp (arithmetic)
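// Example (kp = 2, i.e. divide by 4): x = 5 -> t = 7 -> 1; x = 6 -> t = 8 -> 2;
// x = -5 -> bias = 1, t = -4 -> -1; x = -6 -> t = -5 -> -2 (round to nearest, ties away from zero).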
let kp = -k;
#[cfg(debug_assertions)]
{
debug_assert!((1..=63).contains(&kp));
}
let cnt_right: __m128i = _mm_cvtsi32_si128(kp as i32);
let bias_base: __m256i = _mm256_set1_epi64x(1_i64 << (kp - 1));
let top_mask: __m256i = _mm256_set1_epi64x(-1_i64 << (64 - kp)); // high kp bits
let zero: __m256i = _mm256_setzero_si256();
for _ in 0..span {
let x = _mm256_loadu_si256(aa);
// bias = (1 << (kp-1)) - sign_bit
let sign_bit_x: __m256i = _mm256_srli_epi64(x, 63);
let bias: __m256i = _mm256_sub_epi64(bias_base, sign_bit_x);
// t = x + bias
let t: __m256i = _mm256_add_epi64(x, bias);
// logical shift
let lsr: __m256i = _mm256_srl_epi64(t, cnt_right);
// sign extension
let neg: __m256i = _mm256_cmpgt_epi64(zero, t);
let fill: __m256i = _mm256_and_si256(neg, top_mask);
let y: __m256i = _mm256_or_si256(lsr, fill);
_mm256_storeu_si256(rr, y);
rr = rr.add(1);
aa = aa.add(1);
}
}
// tail
if !n.is_multiple_of(4) {
use poulpy_hal::reference::znx::znx_mul_power_of_two_ref;
znx_mul_power_of_two_ref(k, &mut res[span << 2..], &a[span << 2..]);
}
}
/// Multiply/divide `res` in place by a power of two with rounding matching [poulpy_hal::reference::znx::znx_mul_power_of_two_inplace_ref].
///
/// # Safety
/// Caller must ensure the CPU supports AVX2 (e.g., via `is_x86_feature_detected!("avx2")`).
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
pub unsafe fn znx_mul_power_of_two_inplace_avx(k: i64, res: &mut [i64]) {
use core::arch::x86_64::{
__m128i, __m256i, _mm_cvtsi32_si128, _mm256_add_epi64, _mm256_and_si256, _mm256_cmpgt_epi64, _mm256_loadu_si256,
_mm256_or_si256, _mm256_set1_epi64x, _mm256_setzero_si256, _mm256_sll_epi64, _mm256_srl_epi64, _mm256_srli_epi64,
_mm256_storeu_si256, _mm256_sub_epi64,
};
let n: usize = res.len();
if n == 0 {
return;
}
if k == 0 {
return;
}
let span: usize = n >> 2; // number of 256-bit chunks
unsafe {
let mut rr: *mut __m256i = res.as_mut_ptr() as *mut __m256i;
if k > 0 {
// Left shift by k (variable count).
#[cfg(debug_assertions)]
{
debug_assert!(k <= 63);
}
let cnt128: __m128i = _mm_cvtsi32_si128(k as i32);
for _ in 0..span {
let x: __m256i = _mm256_loadu_si256(rr);
let y: __m256i = _mm256_sll_epi64(x, cnt128);
_mm256_storeu_si256(rr, y);
rr = rr.add(1);
}
// tail
if !n.is_multiple_of(4) {
use poulpy_hal::reference::znx::znx_mul_power_of_two_inplace_ref;
znx_mul_power_of_two_inplace_ref(k, &mut res[span << 2..]);
}
return;
}
// k < 0 => arithmetic right shift with rounding:
// for each x:
// sign_bit = (x >> 63) & 1
// bias = (1<<(kp-1)) - sign_bit
// t = x + bias
// y = t >> kp (arithmetic)
let kp = -k;
#[cfg(debug_assertions)]
{
debug_assert!((1..=63).contains(&kp));
}
let cnt_right: __m128i = _mm_cvtsi32_si128(kp as i32);
let bias_base: __m256i = _mm256_set1_epi64x(1_i64 << (kp - 1));
let top_mask: __m256i = _mm256_set1_epi64x(-1_i64 << (64 - kp)); // high kp bits
let zero: __m256i = _mm256_setzero_si256();
for _ in 0..span {
let x = _mm256_loadu_si256(rr);
// bias = (1 << (kp-1)) - sign_bit
let sign_bit_x: __m256i = _mm256_srli_epi64(x, 63);
let bias: __m256i = _mm256_sub_epi64(bias_base, sign_bit_x);
// t = x + bias
let t: __m256i = _mm256_add_epi64(x, bias);
// logical shift
let lsr: __m256i = _mm256_srl_epi64(t, cnt_right);
// sign extension
let neg: __m256i = _mm256_cmpgt_epi64(zero, t);
let fill: __m256i = _mm256_and_si256(neg, top_mask);
let y: __m256i = _mm256_or_si256(lsr, fill);
_mm256_storeu_si256(rr, y);
rr = rr.add(1);
}
}
// tail
if !n.is_multiple_of(4) {
use poulpy_hal::reference::znx::znx_mul_power_of_two_inplace_ref;
znx_mul_power_of_two_inplace_ref(k, &mut res[span << 2..]);
}
}
/// Multiply/divide `a` by a power of two and add the result to `res`, with rounding matching [poulpy_hal::reference::znx::znx_mul_add_power_of_two_ref].
///
/// # Safety
/// Caller must ensure the CPU supports AVX2 (e.g., via `is_x86_feature_detected!("avx2")`);
/// all inputs must have the same length and must not alias.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
pub unsafe fn znx_mul_add_power_of_two_avx(k: i64, res: &mut [i64], a: &[i64]) {
#[cfg(debug_assertions)]
{
assert_eq!(res.len(), a.len());
}
use core::arch::x86_64::{
__m128i, __m256i, _mm_cvtsi32_si128, _mm256_add_epi64, _mm256_and_si256, _mm256_cmpgt_epi64, _mm256_loadu_si256,
_mm256_or_si256, _mm256_set1_epi64x, _mm256_setzero_si256, _mm256_sll_epi64, _mm256_srl_epi64, _mm256_srli_epi64,
_mm256_storeu_si256, _mm256_sub_epi64,
};
let n: usize = res.len();
if n == 0 {
return;
}
if k == 0 {
use crate::znx_avx::znx_add_inplace_avx;
znx_add_inplace_avx(res, a);
return;
}
let span: usize = n >> 2; // number of 256-bit chunks
unsafe {
let mut rr: *mut __m256i = res.as_mut_ptr() as *mut __m256i;
let mut aa: *const __m256i = a.as_ptr() as *const __m256i;
if k > 0 {
// Left shift by k (variable count).
#[cfg(debug_assertions)]
{
debug_assert!(k <= 63);
}
let cnt128: __m128i = _mm_cvtsi32_si128(k as i32);
for _ in 0..span {
let x: __m256i = _mm256_loadu_si256(aa);
let y: __m256i = _mm256_loadu_si256(rr);
_mm256_storeu_si256(rr, _mm256_add_epi64(y, _mm256_sll_epi64(x, cnt128)));
rr = rr.add(1);
aa = aa.add(1);
}
// tail
if !n.is_multiple_of(4) {
use poulpy_hal::reference::znx::znx_mul_add_power_of_two_ref;
znx_mul_add_power_of_two_ref(k, &mut res[span << 2..], &a[span << 2..]);
}
return;
}
// k < 0 => arithmetic right shift with rounding:
// for each x:
// sign_bit = (x >> 63) & 1
// bias = (1<<(kp-1)) - sign_bit
// t = x + bias
// y = t >> kp (arithmetic)
let kp = -k;
#[cfg(debug_assertions)]
{
debug_assert!((1..=63).contains(&kp));
}
let cnt_right: __m128i = _mm_cvtsi32_si128(kp as i32);
let bias_base: __m256i = _mm256_set1_epi64x(1_i64 << (kp - 1));
let top_mask: __m256i = _mm256_set1_epi64x(-1_i64 << (64 - kp)); // high kp bits
let zero: __m256i = _mm256_setzero_si256();
for _ in 0..span {
let x: __m256i = _mm256_loadu_si256(aa);
let y: __m256i = _mm256_loadu_si256(rr);
// bias = (1 << (kp-1)) - sign_bit
let sign_bit_x: __m256i = _mm256_srli_epi64(x, 63);
let bias: __m256i = _mm256_sub_epi64(bias_base, sign_bit_x);
// t = x + bias
let t: __m256i = _mm256_add_epi64(x, bias);
// logical shift
let lsr: __m256i = _mm256_srl_epi64(t, cnt_right);
// sign extension
let neg: __m256i = _mm256_cmpgt_epi64(zero, t);
let fill: __m256i = _mm256_and_si256(neg, top_mask);
let out: __m256i = _mm256_or_si256(lsr, fill);
_mm256_storeu_si256(rr, _mm256_add_epi64(y, out));
rr = rr.add(1);
aa = aa.add(1);
}
}
// tail
if !n.is_multiple_of(4) {
use poulpy_hal::reference::znx::znx_mul_add_power_of_two_ref;
znx_mul_add_power_of_two_ref(k, &mut res[span << 2..], &a[span << 2..]);
}
}
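// Illustrative test sketch (mirrors the style of the automorphism tests; assumes a
// length that is a multiple of 4 so no scalar tail is involved).
#[cfg(test)]
mod mul_power_of_two_tests_sketch {
use super::*;
use poulpy_hal::reference::znx::znx_mul_power_of_two_ref;
#[test]
fn matches_reference_for_negative_k() {
if !std::arch::is_x86_feature_detected!("avx2") {
eprintln!("skipping: CPU lacks avx2");
return;
}
let a: Vec<i64> = (-8..8).map(|x: i64| x * 3).collect();
let mut r0: Vec<i64> = vec![0i64; a.len()];
let mut r1: Vec<i64> = vec![0i64; a.len()];
znx_mul_power_of_two_ref(-2, &mut r0, &a);
unsafe { znx_mul_power_of_two_avx(-2, &mut r1, &a) };
assert_eq!(r0, r1);
}
}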

View File

@@ -0,0 +1,62 @@
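/// Lane-wise negation: `res[i] = -src[i]`.
///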
/// # Safety
/// Caller must ensure the CPU supports AVX2 (e.g., via `is_x86_feature_detected!("avx2")`);
/// all inputs must have the same length and must not alias.
#[target_feature(enable = "avx2")]
pub fn znx_negate_avx(res: &mut [i64], src: &[i64]) {
#[cfg(debug_assertions)]
{
assert_eq!(res.len(), src.len())
}
let n: usize = res.len();
use std::arch::x86_64::{__m256i, _mm256_loadu_si256, _mm256_setzero_si256, _mm256_storeu_si256, _mm256_sub_epi64};
let span: usize = n >> 2;
unsafe {
let mut aa: *const __m256i = src.as_ptr() as *const __m256i;
let mut rr: *mut __m256i = res.as_mut_ptr() as *mut __m256i;
let zero: __m256i = _mm256_setzero_si256();
for _ in 0..span {
let v: __m256i = _mm256_loadu_si256(aa);
let neg: __m256i = _mm256_sub_epi64(zero, v);
_mm256_storeu_si256(rr, neg);
aa = aa.add(1);
rr = rr.add(1);
}
}
if !res.len().is_multiple_of(4) {
use poulpy_hal::reference::znx::znx_negate_ref;
znx_negate_ref(&mut res[span << 2..], &src[span << 2..])
}
}
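/// Lane-wise in-place negation: `res[i] = -res[i]`.
///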
/// # Safety
/// Caller must ensure the CPU supports AVX2 (e.g., via `is_x86_feature_detected!("avx2")`).
#[target_feature(enable = "avx2")]
pub fn znx_negate_inplace_avx(res: &mut [i64]) {
let n: usize = res.len();
use std::arch::x86_64::{__m256i, _mm256_loadu_si256, _mm256_setzero_si256, _mm256_storeu_si256, _mm256_sub_epi64};
let span: usize = n >> 2;
unsafe {
let mut rr: *mut __m256i = res.as_mut_ptr() as *mut __m256i;
let zero: __m256i = _mm256_setzero_si256();
for _ in 0..span {
let v: __m256i = _mm256_loadu_si256(rr);
let neg: __m256i = _mm256_sub_epi64(zero, v);
_mm256_storeu_si256(rr, neg);
rr = rr.add(1);
}
}
if !res.len().is_multiple_of(4) {
use poulpy_hal::reference::znx::znx_negate_inplace_ref;
znx_negate_inplace_ref(&mut res[span << 2..])
}
}

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,110 @@
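/// Lane-wise subtraction: `res[i] = a[i] - b[i]`.
///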
/// # Safety
/// Caller must ensure the CPU supports AVX2 (e.g., via `is_x86_feature_detected!("avx2")`);
/// all inputs must have the same length and must not alias.
#[target_feature(enable = "avx2")]
pub fn znx_sub_avx(res: &mut [i64], a: &[i64], b: &[i64]) {
#[cfg(debug_assertions)]
{
assert_eq!(res.len(), a.len());
assert_eq!(res.len(), b.len());
}
use core::arch::x86_64::{__m256i, _mm256_loadu_si256, _mm256_storeu_si256, _mm256_sub_epi64};
let n: usize = res.len();
let span: usize = n >> 2;
let mut rr: *mut __m256i = res.as_mut_ptr() as *mut __m256i;
let mut aa: *const __m256i = a.as_ptr() as *const __m256i;
let mut bb: *const __m256i = b.as_ptr() as *const __m256i;
unsafe {
for _ in 0..span {
let diff: __m256i = _mm256_sub_epi64(_mm256_loadu_si256(aa), _mm256_loadu_si256(bb));
_mm256_storeu_si256(rr, diff);
rr = rr.add(1);
aa = aa.add(1);
bb = bb.add(1);
}
}
// tail
if !res.len().is_multiple_of(4) {
use poulpy_hal::reference::znx::znx_sub_ref;
znx_sub_ref(&mut res[span << 2..], &a[span << 2..], &b[span << 2..]);
}
}
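/// Lane-wise in-place subtraction: `res[i] = res[i] - a[i]`.
///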
/// # Safety
/// Caller must ensure the CPU supports AVX2 (e.g., via `is_x86_feature_detected!("avx2")`);
/// all inputs must have the same length and must not alias.
#[target_feature(enable = "avx2")]
pub fn znx_sub_inplace_avx(res: &mut [i64], a: &[i64]) {
#[cfg(debug_assertions)]
{
assert_eq!(res.len(), a.len());
}
use core::arch::x86_64::{__m256i, _mm256_loadu_si256, _mm256_storeu_si256, _mm256_sub_epi64};
let n: usize = res.len();
let span: usize = n >> 2;
let mut rr: *mut __m256i = res.as_mut_ptr() as *mut __m256i;
let mut aa: *const __m256i = a.as_ptr() as *const __m256i;
unsafe {
for _ in 0..span {
let diff: __m256i = _mm256_sub_epi64(_mm256_loadu_si256(rr), _mm256_loadu_si256(aa));
_mm256_storeu_si256(rr, diff);
rr = rr.add(1);
aa = aa.add(1);
}
}
// tail
if !res.len().is_multiple_of(4) {
use poulpy_hal::reference::znx::znx_sub_inplace_ref;
znx_sub_inplace_ref(&mut res[span << 2..], &a[span << 2..]);
}
}
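/// Lane-wise reversed in-place subtraction: `res[i] = a[i] - res[i]`.
///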
/// # Safety
/// Caller must ensure the CPU supports AVX2 (e.g., via `is_x86_feature_detected!("avx2")`);
/// all inputs must have the same length and must not alias.
#[target_feature(enable = "avx2")]
pub fn znx_sub_negate_inplace_avx(res: &mut [i64], a: &[i64]) {
#[cfg(debug_assertions)]
{
assert_eq!(res.len(), a.len());
}
use core::arch::x86_64::{__m256i, _mm256_loadu_si256, _mm256_storeu_si256, _mm256_sub_epi64};
let n: usize = res.len();
let span: usize = n >> 2;
let mut rr: *mut __m256i = res.as_mut_ptr() as *mut __m256i;
let mut aa: *const __m256i = a.as_ptr() as *const __m256i;
unsafe {
for _ in 0..span {
let diff: __m256i = _mm256_sub_epi64(_mm256_loadu_si256(aa), _mm256_loadu_si256(rr));
_mm256_storeu_si256(rr, diff);
rr = rr.add(1);
aa = aa.add(1);
}
}
// tail
if !res.len().is_multiple_of(4) {
use poulpy_hal::reference::znx::znx_sub_negate_inplace_ref;
znx_sub_negate_inplace_ref(&mut res[span << 2..], &a[span << 2..]);
}
}

View File

@@ -0,0 +1,86 @@
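/// Switch `a` to the ring degree of `res`: equal lengths copy; when downsampling,
/// `res[k] = a[k * (n_in / n_out)]`; when upsampling, `res[k * (n_out / n_in)] = a[k]`
/// and the remaining outputs are zeroed.
///
/// # Safety
/// Caller must ensure the CPU supports AVX2 (e.g., via `is_x86_feature_detected!("avx2")`);
/// `res` and `a` must not alias, `a.len()` must be a power of two, one length must divide
/// the other, and the vectorized paths assume the smaller length is a multiple of 4
/// (there is no scalar tail here).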
#[target_feature(enable = "avx2")]
pub unsafe fn znx_switch_ring_avx(res: &mut [i64], a: &[i64]) {
unsafe {
use core::arch::x86_64::*;
let (n_in, n_out) = (a.len(), res.len());
#[cfg(debug_assertions)]
{
assert!(n_in.is_power_of_two());
assert!(n_in.max(n_out).is_multiple_of(n_in.min(n_out)))
}
if n_in == n_out {
use poulpy_hal::reference::znx::znx_copy_ref;
znx_copy_ref(res, a);
return;
}
if n_in > n_out {
// Downsample: res[k] = a[k * gap_in], contiguous stores
let gap_in: usize = n_in / n_out;
// lane offsets (in elements): [0, 1, 2, 3] * gap_in
let step: __m256i = _mm256_setr_epi64x(0, gap_in as i64, 2 * gap_in as i64, 3 * gap_in as i64);
let span: usize = n_out >> 2;
let bump: __m256i = _mm256_set1_epi64x(4 * gap_in as i64);
let mut res_4xi64: *mut __m256i = res.as_mut_ptr() as *mut __m256i;
let a_ptr: *const i64 = a.as_ptr();
let mut base: __m256i = _mm256_setzero_si256(); // starts at 0*gap
for _ in 0..span {
// idx = base + step
let idx: __m256i = _mm256_add_epi64(base, step);
// gather 4 spaced i64 (scale=8 bytes)
let v: __m256i = _mm256_i64gather_epi64(a_ptr, idx, 8);
// store contiguously
_mm256_storeu_si256(res_4xi64, v);
base = _mm256_add_epi64(base, bump);
res_4xi64 = res_4xi64.add(1);
}
} else {
// Upsample: res[k * gap_out] = a[k], i.e. res has holes;
use poulpy_hal::reference::znx::znx_zero_ref;
let gap_out = n_out / n_in;
// zero then scatter scalar stores
znx_zero_ref(res);
let mut a_4xi64: *const __m256i = a.as_ptr() as *const __m256i;
for i in (0..n_in).step_by(4) {
// Load contiguously 4 inputs
let v = _mm256_loadu_si256(a_4xi64);
// extract 4 lanes (pextrq). This is still the best we can do on AVX2.
let x0: i64 = _mm256_extract_epi64(v, 0);
let x1: i64 = _mm256_extract_epi64(v, 1);
let x2: i64 = _mm256_extract_epi64(v, 2);
let x3: i64 = _mm256_extract_epi64(v, 3);
// starting output pointer for this group
let mut p: *mut i64 = res.as_mut_ptr().add(i * gap_out);
// four strided stores with pointer bump (avoid mul each time)
*p = x0;
p = p.add(gap_out);
*p = x1;
p = p.add(gap_out);
*p = x2;
p = p.add(gap_out);
*p = x3;
a_4xi64 = a_4xi64.add(1)
}
}
}
}
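// Illustrative test sketch (assumes the semantics documented above: downsample keeps
// every `n_in / n_out`-th coefficient, upsample interleaves `a` with zeros).
#[cfg(test)]
mod switch_ring_tests_sketch {
use super::*;
#[test]
fn down_and_up_sample() {
if !std::arch::is_x86_feature_detected!("avx2") {
eprintln!("skipping: CPU lacks avx2");
return;
}
let a: Vec<i64> = (0..16).collect();
// downsample 16 -> 8: every second coefficient
let mut down = vec![0i64; 8];
unsafe { znx_switch_ring_avx(&mut down, &a) };
assert_eq!(down, (0..16).step_by(2).collect::<Vec<i64>>());
// upsample 8 -> 16: original values interleaved with zeros
let mut up = vec![0i64; 16];
unsafe { znx_switch_ring_avx(&mut up, &down) };
assert_eq!(up.iter().step_by(2).copied().collect::<Vec<i64>>(), down);
assert!(up.iter().skip(1).step_by(2).all(|&x| x == 0));
}
}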