Mirror of https://github.com/arnaucube/poulpy.git, synced 2026-02-10 13:16:44 +01:00.
Support for bivariate convolution & normalization with offset (#126)
* Add bivariate-convolution
* Add pair-wise convolution + tests + benches
* Add take_cnv_pvec_[left/right] to Scratch & updated CHANGELOG.md
* cross-base2k normalization with positive offset
* clippy & fix CI doctest avx compile error
* more streamlined bounds derivation for normalization
* Working cross-base2k normalization with pos/neg offset
* Update normalization API & tests
* Add glwe tensoring test
* Add relinearization + preliminary test
* Fix GGLWEToGGSW key infos
* Add (X,Y) convolution by const (1, Y) poly
* Faster normalization test + add bench for cnv_by_const
* Update changelog

(A scalar sketch of the limb-wise convolution these kernels implement follows after the file summary below.)
Committed by GitHub. Parent: 76424d0ab5. Commit: 4e90e08a71.
poulpy-cpu-avx/src/convolution.rs (Normal file, 401 lines)
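As a point of reference for the diff below, here is a minimal scalar sketch of the limb-wise "convolution by const" operation the new AVX kernels vectorize (the helper name and shape are hypothetical, not part of the crate): output limb k accumulates a[k - j] * b[j] over every j that keeps both indices in range, the same j_min/j_max bounds the kernels compute. The full (X, Y) convolution replaces the scalar b[j] with an 8-lane complex block and the lane product with a complex multiply.

fn naive_limb_convolution(a: &[[i64; 8]], b: &[i64]) -> Vec<[i64; 8]> {
    // a: a_size limbs of 8 lanes each; b: b_size scalar constants.
    let (a_size, b_size) = (a.len(), b.len());
    if a_size == 0 || b_size == 0 {
        return Vec::new();
    }
    let mut res = vec![[0i64; 8]; a_size + b_size - 1];
    for (k, out) in res.iter_mut().enumerate() {
        // Keep both k - j and j in range (same bounds as the AVX kernels).
        let j_min = (k + 1).saturating_sub(a_size);
        let j_max = (k + 1).min(b_size);
        for j in j_min..j_max {
            for lane in 0..8 {
                out[lane] += a[k - j][lane] * b[j];
            }
        }
    }
    res
}

The i64 kernels below compute this inner sum for one output limb (or two adjacent ones) at a time, holding the 8 lanes in two 256-bit accumulators; the reim4 kernels do the same with complex (re, im) blocks of 4 doubles each.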
@@ -0,0 +1,401 @@
|
||||
use poulpy_hal::{
|
||||
api::{Convolution, ModuleN, ScratchTakeBasic, TakeSlice, VecZnxDftApply, VecZnxDftBytesOf},
|
||||
layouts::{
|
||||
Backend, CnvPVecL, CnvPVecLToMut, CnvPVecLToRef, CnvPVecR, CnvPVecRToMut, CnvPVecRToRef, Module, Scratch, VecZnx,
|
||||
VecZnxBig, VecZnxBigToMut, VecZnxDft, VecZnxDftToMut, VecZnxToRef, ZnxInfos,
|
||||
},
|
||||
oep::{CnvPVecBytesOfImpl, CnvPVecLAllocImpl, ConvolutionImpl},
|
||||
reference::fft64::convolution::{
|
||||
convolution_apply_dft, convolution_apply_dft_tmp_bytes, convolution_by_const_apply, convolution_by_const_apply_tmp_bytes,
|
||||
convolution_pairwise_apply_dft, convolution_pairwise_apply_dft_tmp_bytes, convolution_prepare_left,
|
||||
convolution_prepare_right,
|
||||
},
|
||||
};
|
||||
|
||||
use crate::{FFT64Avx, module::FFT64ModuleHandle};
|
||||
|
||||
unsafe impl CnvPVecLAllocImpl<Self> for FFT64Avx {
|
||||
fn cnv_pvec_left_alloc_impl(n: usize, cols: usize, size: usize) -> CnvPVecL<Vec<u8>, Self> {
|
||||
CnvPVecL::alloc(n, cols, size)
|
||||
}
|
||||
|
||||
fn cnv_pvec_right_alloc_impl(n: usize, cols: usize, size: usize) -> CnvPVecR<Vec<u8>, Self> {
|
||||
CnvPVecR::alloc(n, cols, size)
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl CnvPVecBytesOfImpl for FFT64Avx {
|
||||
fn bytes_of_cnv_pvec_left_impl(n: usize, cols: usize, size: usize) -> usize {
|
||||
Self::layout_prep_word_count() * n * cols * size * size_of::<<FFT64Avx as Backend>::ScalarPrep>()
|
||||
}
|
||||
|
||||
fn bytes_of_cnv_pvec_right_impl(n: usize, cols: usize, size: usize) -> usize {
|
||||
Self::layout_prep_word_count() * n * cols * size * size_of::<<FFT64Avx as Backend>::ScalarPrep>()
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl ConvolutionImpl<Self> for FFT64Avx
|
||||
where
|
||||
Module<Self>: ModuleN + VecZnxDftBytesOf + VecZnxDftApply<Self>,
|
||||
{
|
||||
fn cnv_prepare_left_tmp_bytes_impl(module: &Module<Self>, res_size: usize, a_size: usize) -> usize {
|
||||
module.bytes_of_vec_znx_dft(1, res_size.min(a_size))
|
||||
}
|
||||
|
||||
fn cnv_prepare_left_impl<R, A>(module: &Module<Self>, res: &mut R, a: &A, scratch: &mut Scratch<Self>)
|
||||
where
|
||||
R: CnvPVecLToMut<Self>,
|
||||
A: VecZnxToRef,
|
||||
{
|
||||
let res: &mut CnvPVecL<&mut [u8], FFT64Avx> = &mut res.to_mut();
|
||||
let a: &VecZnx<&[u8]> = &a.to_ref();
|
||||
let (mut tmp, _) = scratch.take_vec_znx_dft(module, 1, res.size().min(a.size()));
|
||||
convolution_prepare_left(module.get_fft_table(), res, a, &mut tmp);
|
||||
}
|
||||
|
||||
fn cnv_prepare_right_tmp_bytes_impl(module: &Module<Self>, res_size: usize, a_size: usize) -> usize {
|
||||
module.bytes_of_vec_znx_dft(1, res_size.min(a_size))
|
||||
}
|
||||
|
||||
fn cnv_prepare_right_impl<R, A>(module: &Module<Self>, res: &mut R, a: &A, scratch: &mut Scratch<Self>)
|
||||
where
|
||||
R: CnvPVecRToMut<Self>,
|
||||
A: VecZnxToRef,
|
||||
{
|
||||
let res: &mut CnvPVecR<&mut [u8], FFT64Avx> = &mut res.to_mut();
|
||||
let a: &VecZnx<&[u8]> = &a.to_ref();
|
||||
let (mut tmp, _) = scratch.take_vec_znx_dft(module, 1, res.size().min(a.size()));
|
||||
convolution_prepare_right(module.get_fft_table(), res, a, &mut tmp);
|
||||
}
|
||||
fn cnv_apply_dft_tmp_bytes_impl(
|
||||
_module: &Module<Self>,
|
||||
res_size: usize,
|
||||
_res_offset: usize,
|
||||
a_size: usize,
|
||||
b_size: usize,
|
||||
) -> usize {
|
||||
convolution_apply_dft_tmp_bytes(res_size, a_size, b_size)
|
||||
}
|
||||
|
||||
fn cnv_by_const_apply_tmp_bytes_impl(
|
||||
_module: &Module<Self>,
|
||||
res_size: usize,
|
||||
_res_offset: usize,
|
||||
a_size: usize,
|
||||
b_size: usize,
|
||||
) -> usize {
|
||||
convolution_by_const_apply_tmp_bytes(res_size, a_size, b_size)
|
||||
}
|
||||
|
||||
fn cnv_by_const_apply_impl<R, A>(
|
||||
module: &Module<Self>,
|
||||
res: &mut R,
|
||||
res_offset: usize,
|
||||
res_col: usize,
|
||||
a: &A,
|
||||
a_col: usize,
|
||||
b: &[i64],
|
||||
scratch: &mut Scratch<Self>,
|
||||
) where
|
||||
R: VecZnxBigToMut<Self>,
|
||||
A: VecZnxToRef,
|
||||
{
|
||||
let res: &mut VecZnxBig<&mut [u8], Self> = &mut res.to_mut();
|
||||
let a: &VecZnx<&[u8]> = &a.to_ref();
|
||||
let (tmp, _) =
|
||||
scratch.take_slice(module.cnv_by_const_apply_tmp_bytes(res.size(), res_offset, a.size(), b.len()) / size_of::<i64>());
|
||||
convolution_by_const_apply(res, res_offset, res_col, a, a_col, b, tmp);
|
||||
}
|
||||
|
||||
fn cnv_apply_dft_impl<R, A, B>(
|
||||
module: &Module<Self>,
|
||||
res: &mut R,
|
||||
res_offset: usize,
|
||||
res_col: usize,
|
||||
a: &A,
|
||||
a_col: usize,
|
||||
b: &B,
|
||||
b_col: usize,
|
||||
scratch: &mut Scratch<Self>,
|
||||
) where
|
||||
R: VecZnxDftToMut<Self>,
|
||||
A: CnvPVecLToRef<Self>,
|
||||
B: CnvPVecRToRef<Self>,
|
||||
{
|
||||
let res: &mut VecZnxDft<&mut [u8], FFT64Avx> = &mut res.to_mut();
|
||||
let a: &CnvPVecL<&[u8], FFT64Avx> = &a.to_ref();
|
||||
let b: &CnvPVecR<&[u8], FFT64Avx> = &b.to_ref();
|
||||
let (tmp, _) =
|
||||
scratch.take_slice(module.cnv_apply_dft_tmp_bytes(res.size(), res_offset, a.size(), b.size()) / size_of::<f64>());
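// cnv_apply_dft_tmp_bytes returns a size in bytes, while take_slice is given
// an element count, hence the division by size_of::<f64>().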
|
||||
convolution_apply_dft(res, res_offset, res_col, a, a_col, b, b_col, tmp);
|
||||
}
|
||||
|
||||
fn cnv_pairwise_apply_dft_tmp_bytes(
|
||||
_module: &Module<Self>,
|
||||
res_size: usize,
|
||||
_res_offset: usize,
|
||||
a_size: usize,
|
||||
b_size: usize,
|
||||
) -> usize {
|
||||
convolution_pairwise_apply_dft_tmp_bytes(res_size, a_size, b_size)
|
||||
}
|
||||
|
||||
fn cnv_pairwise_apply_dft_impl<R, A, B>(
|
||||
module: &Module<Self>,
|
||||
res: &mut R,
|
||||
res_offset: usize,
|
||||
res_col: usize,
|
||||
a: &A,
|
||||
b: &B,
|
||||
i: usize,
|
||||
j: usize,
|
||||
scratch: &mut Scratch<Self>,
|
||||
) where
|
||||
R: VecZnxDftToMut<Self>,
|
||||
A: CnvPVecLToRef<Self>,
|
||||
B: CnvPVecRToRef<Self>,
|
||||
{
|
||||
let res: &mut VecZnxDft<&mut [u8], FFT64Avx> = &mut res.to_mut();
|
||||
let a: &CnvPVecL<&[u8], FFT64Avx> = &a.to_ref();
|
||||
let b: &CnvPVecR<&[u8], FFT64Avx> = &b.to_ref();
|
||||
let (tmp, _) = scratch
|
||||
.take_slice(module.cnv_pairwise_apply_dft_tmp_bytes(res.size(), res_offset, a.size(), b.size()) / size_of::<f64>());
|
||||
convolution_pairwise_apply_dft(res, res_offset, res_col, a, b, i, j, tmp);
|
||||
}
|
||||
}
|
||||
|
||||
/// # Safety
|
||||
/// Caller must ensure the CPU supports AVX2.
|
||||
/// Assumes all inputs fit in i32 (so i32×i32→i64 is exact).
|
||||
#[target_feature(enable = "avx2")]
|
||||
pub unsafe fn i64_convolution_by_const_1coeff_avx(k: usize, dst: &mut [i64; 8], a: &[i64], a_size: usize, b: &[i64]) {
|
||||
use core::arch::x86_64::{
|
||||
__m256i, _mm256_add_epi64, _mm256_loadu_si256, _mm256_mul_epi32, _mm256_set1_epi32, _mm256_setzero_si256,
|
||||
_mm256_storeu_si256,
|
||||
};
|
||||
|
||||
dst.fill(0);
|
||||
|
||||
let b_size = b.len();
|
||||
if k >= a_size + b_size {
|
||||
return;
|
||||
}
|
||||
|
||||
let j_min = k.saturating_sub(a_size - 1);
|
||||
let j_max = (k + 1).min(b_size);
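// dst[k] = sum_j a[k - j] * b[j]: keep both indices in range, i.e.
// 0 <= k - j < a_size and 0 <= j < b_size, hence
// j ∈ [max(0, k - (a_size - 1)), min(k + 1, b_size)).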
|
||||
|
||||
unsafe {
|
||||
// Two accumulators = 8 outputs total
|
||||
let mut acc_lo: __m256i = _mm256_setzero_si256(); // dst[0..4)
|
||||
let mut acc_hi: __m256i = _mm256_setzero_si256(); // dst[4..8)
|
||||
|
||||
let mut a_ptr: *const i64 = a.as_ptr().add(8 * (k - j_min));
|
||||
let mut b_ptr: *const i64 = b.as_ptr().add(j_min);
|
||||
|
||||
for _ in 0..(j_max - j_min) {
|
||||
// Broadcast scalar b[j] as i32
|
||||
let br: __m256i = _mm256_set1_epi32(*b_ptr as i32);
|
||||
|
||||
// ---- lower half: a[0..4) ----
|
||||
let a_lo: __m256i = _mm256_loadu_si256(a_ptr as *const __m256i);
|
||||
|
||||
let prod_lo: __m256i = _mm256_mul_epi32(a_lo, br);
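// _mm256_mul_epi32 multiplies the sign-extended low 32 bits of each 64-bit
// lane into a full 64-bit product, hence the i32 precondition in the doc
// comment above.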
|
||||
|
||||
acc_lo = _mm256_add_epi64(acc_lo, prod_lo);
|
||||
|
||||
// ---- upper half: a[4..8) ----
|
||||
let a_hi: __m256i = _mm256_loadu_si256(a_ptr.add(4) as *const __m256i);
|
||||
|
||||
let prod_hi: __m256i = _mm256_mul_epi32(a_hi, br);
|
||||
|
||||
acc_hi = _mm256_add_epi64(acc_hi, prod_hi);
|
||||
|
||||
a_ptr = a_ptr.sub(8);
|
||||
b_ptr = b_ptr.add(1);
|
||||
}
|
||||
|
||||
// Store final result
|
||||
_mm256_storeu_si256(dst.as_mut_ptr() as *mut __m256i, acc_lo);
|
||||
_mm256_storeu_si256(dst.as_mut_ptr().add(4) as *mut __m256i, acc_hi);
|
||||
}
|
||||
}
|
||||
|
||||
/// # Safety
|
||||
/// Caller must ensure the CPU supports AVX2.
|
||||
/// Assumes all values in `a` and `b` fit in i32 (so i32×i32→i64 is exact).
|
||||
#[target_feature(enable = "avx2")]
|
||||
pub unsafe fn i64_convolution_by_real_const_2coeffs_avx(
|
||||
k: usize,
|
||||
dst: &mut [i64; 16],
|
||||
a: &[i64],
|
||||
a_size: usize,
|
||||
b: &[i64], // real scalars, stride-1
|
||||
) {
|
||||
use core::arch::x86_64::{
|
||||
__m256i, _mm256_add_epi64, _mm256_loadu_si256, _mm256_mul_epi32, _mm256_set1_epi32, _mm256_setzero_si256,
|
||||
_mm256_storeu_si256,
|
||||
};
|
||||
|
||||
let b_size: usize = b.len();
|
||||
|
||||
debug_assert!(a.len() >= 8 * a_size);
|
||||
|
||||
let k0: usize = k;
|
||||
let k1: usize = k + 1;
|
||||
let bound: usize = a_size + b_size;
|
||||
|
||||
if k0 >= bound {
|
||||
unsafe {
|
||||
let zero: __m256i = _mm256_setzero_si256();
|
||||
let dst_ptr: *mut i64 = dst.as_mut_ptr();
|
||||
_mm256_storeu_si256(dst_ptr as *mut __m256i, zero);
|
||||
_mm256_storeu_si256(dst_ptr.add(4) as *mut __m256i, zero);
|
||||
_mm256_storeu_si256(dst_ptr.add(8) as *mut __m256i, zero);
|
||||
_mm256_storeu_si256(dst_ptr.add(12) as *mut __m256i, zero);
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
unsafe {
|
||||
let mut acc_lo_k0: __m256i = _mm256_setzero_si256();
|
||||
let mut acc_hi_k0: __m256i = _mm256_setzero_si256();
|
||||
let mut acc_lo_k1: __m256i = _mm256_setzero_si256();
|
||||
let mut acc_hi_k1: __m256i = _mm256_setzero_si256();
|
||||
|
||||
let j0_min: usize = (k0 + 1).saturating_sub(a_size);
|
||||
let j0_max: usize = (k0 + 1).min(b_size);
|
||||
|
||||
if k1 >= bound {
|
||||
let mut a_k0_ptr: *const i64 = a.as_ptr().add(8 * (k0 - j0_min));
|
||||
let mut b_ptr: *const i64 = b.as_ptr().add(j0_min);
|
||||
|
||||
// Contributions to k0 only
|
||||
for _ in 0..j0_max - j0_min {
|
||||
// Broadcast b[j] as i32
|
||||
let br: __m256i = _mm256_set1_epi32(*b_ptr as i32);
|
||||
|
||||
// Load 4×i64 (low half) and 4×i64 (high half)
|
||||
let a_lo_k0: __m256i = _mm256_loadu_si256(a_k0_ptr as *const __m256i);
|
||||
let a_hi_k0: __m256i = _mm256_loadu_si256(a_k0_ptr.add(4) as *const __m256i);
|
||||
|
||||
acc_lo_k0 = _mm256_add_epi64(acc_lo_k0, _mm256_mul_epi32(a_lo_k0, br));
|
||||
acc_hi_k0 = _mm256_add_epi64(acc_hi_k0, _mm256_mul_epi32(a_hi_k0, br));
|
||||
|
||||
a_k0_ptr = a_k0_ptr.sub(8);
|
||||
b_ptr = b_ptr.add(1);
|
||||
}
|
||||
} else {
|
||||
let j1_min: usize = (k1 + 1).saturating_sub(a_size);
|
||||
let j1_max: usize = (k1 + 1).min(b_size);
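// The j ranges for k0 and k1 overlap: [j0_min, j1_min) contributes to k0
// only, [j1_min, j0_max) to both coefficients, and [j0_max, j1_max) to k1
// only, hence the three loops below.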
|
||||
|
||||
let mut a_k0_ptr: *const i64 = a.as_ptr().add(8 * (k0 - j0_min));
|
||||
let mut a_k1_ptr: *const i64 = a.as_ptr().add(8 * (k1 - j1_min));
|
||||
let mut b_ptr: *const i64 = b.as_ptr().add(j0_min);
|
||||
|
||||
// Region 1: k0 only, j ∈ [j0_min, j1_min)
|
||||
for _ in 0..j1_min - j0_min {
|
||||
let br: __m256i = _mm256_set1_epi32(*b_ptr as i32);
|
||||
|
||||
let a_k0_lo: __m256i = _mm256_loadu_si256(a_k0_ptr as *const __m256i);
|
||||
let a_k0_hi: __m256i = _mm256_loadu_si256(a_k0_ptr.add(4) as *const __m256i);
|
||||
|
||||
acc_lo_k0 = _mm256_add_epi64(acc_lo_k0, _mm256_mul_epi32(a_k0_lo, br));
|
||||
acc_hi_k0 = _mm256_add_epi64(acc_hi_k0, _mm256_mul_epi32(a_k0_hi, br));
|
||||
|
||||
a_k0_ptr = a_k0_ptr.sub(8);
|
||||
b_ptr = b_ptr.add(1);
|
||||
}
|
||||
|
||||
// Region 2: overlap, contributions to both k0 and k1, j ∈ [j1_min, j0_max)
|
||||
// Save one load on b: broadcast once and reuse.
|
||||
for _ in 0..j0_max - j1_min {
|
||||
let br: __m256i = _mm256_set1_epi32(*b_ptr as i32);
|
||||
|
||||
let a_lo_k0: __m256i = _mm256_loadu_si256(a_k0_ptr as *const __m256i);
|
||||
let a_hi_k0: __m256i = _mm256_loadu_si256(a_k0_ptr.add(4) as *const __m256i);
|
||||
let a_lo_k1: __m256i = _mm256_loadu_si256(a_k1_ptr as *const __m256i);
|
||||
let a_hi_k1: __m256i = _mm256_loadu_si256(a_k1_ptr.add(4) as *const __m256i);
|
||||
|
||||
// k0
|
||||
acc_lo_k0 = _mm256_add_epi64(acc_lo_k0, _mm256_mul_epi32(a_lo_k0, br));
|
||||
acc_hi_k0 = _mm256_add_epi64(acc_hi_k0, _mm256_mul_epi32(a_hi_k0, br));
|
||||
|
||||
// k1
|
||||
acc_lo_k1 = _mm256_add_epi64(acc_lo_k1, _mm256_mul_epi32(a_lo_k1, br));
|
||||
acc_hi_k1 = _mm256_add_epi64(acc_hi_k1, _mm256_mul_epi32(a_hi_k1, br));
|
||||
|
||||
a_k0_ptr = a_k0_ptr.sub(8);
|
||||
a_k1_ptr = a_k1_ptr.sub(8);
|
||||
b_ptr = b_ptr.add(1);
|
||||
}
|
||||
|
||||
// Region 3: k1 only, j ∈ [j0_max, j1_max)
|
||||
for _ in 0..j1_max - j0_max {
|
||||
let br: __m256i = _mm256_set1_epi32(*b_ptr as i32);
|
||||
|
||||
let a_lo_k1: __m256i = _mm256_loadu_si256(a_k1_ptr as *const __m256i);
|
||||
let a_hi_k1: __m256i = _mm256_loadu_si256(a_k1_ptr.add(4) as *const __m256i);
|
||||
|
||||
acc_lo_k1 = _mm256_add_epi64(acc_lo_k1, _mm256_mul_epi32(a_lo_k1, br));
|
||||
acc_hi_k1 = _mm256_add_epi64(acc_hi_k1, _mm256_mul_epi32(a_hi_k1, br));
|
||||
|
||||
a_k1_ptr = a_k1_ptr.sub(8);
|
||||
b_ptr = b_ptr.add(1);
|
||||
}
|
||||
}
|
||||
|
||||
let dst_ptr: *mut i64 = dst.as_mut_ptr();
|
||||
_mm256_storeu_si256(dst_ptr as *mut __m256i, acc_lo_k0);
|
||||
_mm256_storeu_si256(dst_ptr.add(4) as *mut __m256i, acc_hi_k0);
|
||||
_mm256_storeu_si256(dst_ptr.add(8) as *mut __m256i, acc_lo_k1);
|
||||
_mm256_storeu_si256(dst_ptr.add(12) as *mut __m256i, acc_hi_k1);
|
||||
}
|
||||
}
|
||||
|
||||
/// # Safety
|
||||
/// Caller must ensure the CPU supports AVX2 (e.g., via `is_x86_feature_detected!("avx2")`).
|
||||
#[target_feature(enable = "avx")]
|
||||
pub fn i64_extract_1blk_contiguous_avx(n: usize, offset: usize, rows: usize, blk: usize, dst: &mut [i64], src: &[i64]) {
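// Gathers block `blk` (8 contiguous i64 at position 8*blk within each row)
// from `rows` consecutive rows of `src`, rows spaced `n` i64 apart starting
// at `offset`, and packs them contiguously into `dst` (8*rows i64).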
|
||||
use core::arch::x86_64::{__m256i, _mm256_loadu_si256, _mm256_storeu_si256};
|
||||
|
||||
unsafe {
|
||||
let mut src_ptr: *const __m256i = src.as_ptr().add(offset + (blk << 3)) as *const __m256i; // src + offset + 8*blk
|
||||
let mut dst_ptr: *mut __m256i = dst.as_mut_ptr() as *mut __m256i;
|
||||
|
||||
let step: usize = n >> 2;
|
||||
|
||||
// Each iteration copies 8 i64; advance src by n i64 each row
|
||||
for _ in 0..rows {
|
||||
let v: __m256i = _mm256_loadu_si256(src_ptr);
|
||||
_mm256_storeu_si256(dst_ptr, v);
|
||||
let v: __m256i = _mm256_loadu_si256(src_ptr.add(1));
|
||||
_mm256_storeu_si256(dst_ptr.add(1), v);
|
||||
dst_ptr = dst_ptr.add(2);
|
||||
src_ptr = src_ptr.add(step);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// # Safety
|
||||
/// Caller must ensure the CPU supports AVX2 (e.g., via `is_x86_feature_detected!("avx2")`).
|
||||
#[target_feature(enable = "avx")]
|
||||
pub fn i64_save_1blk_contiguous_avx(n: usize, offset: usize, rows: usize, blk: usize, dst: &mut [i64], src: &[i64]) {
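// Inverse of `i64_extract_1blk_contiguous_avx`: scatters 8*rows contiguous
// i64 from `src` back into block `blk` of `rows` rows of `dst`, rows spaced
// `n` i64 apart starting at `offset`.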
|
||||
use core::arch::x86_64::{__m256i, _mm256_loadu_si256, _mm256_storeu_si256};
|
||||
|
||||
unsafe {
|
||||
let mut src_ptr: *const __m256i = src.as_ptr() as *const __m256i;
|
||||
let mut dst_ptr: *mut __m256i = dst.as_mut_ptr().add(offset + (blk << 3)) as *mut __m256i; // dst + offset + 8*blk
|
||||
|
||||
let step: usize = n >> 2;
|
||||
|
||||
// Each iteration copies 8 i64; advance dst by n i64 each row
|
||||
for _ in 0..rows {
|
||||
let v: __m256i = _mm256_loadu_si256(src_ptr);
|
||||
_mm256_storeu_si256(dst_ptr, v);
|
||||
let v: __m256i = _mm256_loadu_si256(src_ptr.add(1));
|
||||
_mm256_storeu_si256(dst_ptr.add(1), v);
|
||||
dst_ptr = dst_ptr.add(step);
|
||||
src_ptr = src_ptr.add(2);
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1,7 +1,7 @@
|
||||
// ─────────────────────────────────────────────────────────────
|
||||
// Build the backend **only when ALL conditions are satisfied**
|
||||
// ─────────────────────────────────────────────────────────────
|
||||
#![cfg(all(feature = "enable-avx", target_arch = "x86_64", target_feature = "avx2", target_feature = "fma"))]
|
||||
//#![cfg(all(feature = "enable-avx", target_arch = "x86_64", target_feature = "avx2", target_feature = "fma"))]
|
||||
|
||||
// If the user enables this backend but targets a non-x86_64 CPU → abort
|
||||
#[cfg(all(feature = "enable-avx", not(target_arch = "x86_64")))]
|
||||
@@ -15,6 +15,7 @@ compile_error!("feature `enable-avx` requires AVX2. Build with RUSTFLAGS=\"-C ta
|
||||
#[cfg(all(feature = "enable-avx", target_arch = "x86_64", not(target_feature = "fma")))]
|
||||
compile_error!("feature `enable-avx` requires FMA. Build with RUSTFLAGS=\"-C target-feature=+fma\".");
|
||||
|
||||
mod convolution;
|
||||
mod module;
|
||||
mod reim;
|
||||
mod reim4;
|
||||
|
||||
@@ -5,13 +5,18 @@ use poulpy_hal::{
|
||||
oep::ModuleNewImpl,
|
||||
reference::{
|
||||
fft64::{
|
||||
convolution::{
|
||||
I64ConvolutionByConst1Coeff, I64ConvolutionByConst2Coeffs, I64Extract1BlkContiguous, I64Save1BlkContiguous,
|
||||
},
|
||||
reim::{
|
||||
ReimAdd, ReimAddInplace, ReimAddMul, ReimCopy, ReimDFTExecute, ReimFFTTable, ReimFromZnx, ReimIFFTTable, ReimMul,
|
||||
ReimMulInplace, ReimNegate, ReimNegateInplace, ReimSub, ReimSubInplace, ReimSubNegateInplace, ReimToZnx,
|
||||
ReimToZnxInplace, ReimZero, reim_copy_ref, reim_zero_ref,
|
||||
},
|
||||
reim4::{
|
||||
Reim4Extract1Blk, Reim4Mat1ColProd, Reim4Mat2Cols2ndColProd, Reim4Mat2ColsProd, Reim4Save1Blk, Reim4Save2Blks,
|
||||
Reim4Convolution1Coeff, Reim4Convolution2Coeffs, Reim4ConvolutionByRealConst1Coeff,
|
||||
Reim4ConvolutionByRealConst2Coeffs, Reim4Extract1BlkContiguous, Reim4Mat1ColProd, Reim4Mat2Cols2ndColProd,
|
||||
Reim4Mat2ColsProd, Reim4Save1Blk, Reim4Save1BlkContiguous, Reim4Save2Blks,
|
||||
},
|
||||
},
|
||||
znx::{
|
||||
@@ -26,6 +31,10 @@ use poulpy_hal::{
|
||||
|
||||
use crate::{
|
||||
FFT64Avx,
|
||||
convolution::{
|
||||
i64_convolution_by_const_1coeff_avx, i64_convolution_by_real_const_2coeffs_avx, i64_extract_1blk_contiguous_avx,
|
||||
i64_save_1blk_contiguous_avx,
|
||||
},
|
||||
reim::{
|
||||
ReimFFTAvx, ReimIFFTAvx, reim_add_avx2_fma, reim_add_inplace_avx2_fma, reim_addmul_avx2_fma, reim_from_znx_i64_bnd50_fma,
|
||||
reim_mul_avx2_fma, reim_mul_inplace_avx2_fma, reim_negate_avx2_fma, reim_negate_inplace_avx2_fma, reim_sub_avx2_fma,
|
||||
@@ -33,8 +42,10 @@ use crate::{
|
||||
},
|
||||
reim_to_znx_i64_bnd63_avx2_fma,
|
||||
reim4::{
|
||||
reim4_extract_1blk_from_reim_avx, reim4_save_1blk_to_reim_avx, reim4_save_2blk_to_reim_avx,
|
||||
reim4_vec_mat1col_product_avx, reim4_vec_mat2cols_2ndcol_product_avx, reim4_vec_mat2cols_product_avx,
|
||||
reim4_convolution_1coeff_avx, reim4_convolution_2coeffs_avx, reim4_convolution_by_real_const_1coeff_avx,
|
||||
reim4_convolution_by_real_const_2coeffs_avx, reim4_extract_1blk_from_reim_contiguous_avx, reim4_save_1blk_to_reim_avx,
|
||||
reim4_save_1blk_to_reim_contiguous_avx, reim4_save_2blk_to_reim_avx, reim4_vec_mat1col_product_avx,
|
||||
reim4_vec_mat2cols_2ndcol_product_avx, reim4_vec_mat2cols_product_avx,
|
||||
},
|
||||
znx_avx::{
|
||||
znx_add_avx, znx_add_inplace_avx, znx_automorphism_avx, znx_extract_digit_addmul_avx, znx_mul_add_power_of_two_avx,
|
||||
@@ -470,11 +481,55 @@ impl ReimZero for FFT64Avx {
|
||||
}
|
||||
}
|
||||
|
||||
impl Reim4Extract1Blk for FFT64Avx {
|
||||
impl Reim4Convolution1Coeff for FFT64Avx {
|
||||
#[inline(always)]
|
||||
fn reim4_extract_1blk(m: usize, rows: usize, blk: usize, dst: &mut [f64], src: &[f64]) {
|
||||
fn reim4_convolution_1coeff(k: usize, dst: &mut [f64; 8], a: &[f64], a_size: usize, b: &[f64], b_size: usize) {
|
||||
unsafe {
|
||||
reim4_extract_1blk_from_reim_avx(m, rows, blk, dst, src);
|
||||
reim4_convolution_1coeff_avx(k, dst, a, a_size, b, b_size);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Reim4Convolution2Coeffs for FFT64Avx {
|
||||
#[inline(always)]
|
||||
fn reim4_convolution_2coeffs(k: usize, dst: &mut [f64; 16], a: &[f64], a_size: usize, b: &[f64], b_size: usize) {
|
||||
unsafe {
|
||||
reim4_convolution_2coeffs_avx(k, dst, a, a_size, b, b_size);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Reim4ConvolutionByRealConst1Coeff for FFT64Avx {
|
||||
#[inline(always)]
|
||||
fn reim4_convolution_by_real_const_1coeff(k: usize, dst: &mut [f64; 8], a: &[f64], a_size: usize, b: &[f64]) {
|
||||
unsafe {
|
||||
reim4_convolution_by_real_const_1coeff_avx(k, dst, a, a_size, b);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Reim4ConvolutionByRealConst2Coeffs for FFT64Avx {
|
||||
#[inline(always)]
|
||||
fn reim4_convolution_by_real_const_2coeffs(k: usize, dst: &mut [f64; 16], a: &[f64], a_size: usize, b: &[f64]) {
|
||||
unsafe {
|
||||
reim4_convolution_by_real_const_2coeffs_avx(k, dst, a, a_size, b);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Reim4Extract1BlkContiguous for FFT64Avx {
|
||||
#[inline(always)]
|
||||
fn reim4_extract_1blk_contiguous(m: usize, rows: usize, blk: usize, dst: &mut [f64], src: &[f64]) {
|
||||
unsafe {
|
||||
reim4_extract_1blk_from_reim_contiguous_avx(m, rows, blk, dst, src);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Reim4Save1BlkContiguous for FFT64Avx {
|
||||
fn reim4_save_1blk_contiguous(m: usize, rows: usize, blk: usize, dst: &mut [f64], src: &[f64]) {
|
||||
unsafe {
|
||||
reim4_save_1blk_to_reim_contiguous_avx(m, rows, blk, dst, src);
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -523,3 +578,39 @@ impl Reim4Mat2Cols2ndColProd for FFT64Avx {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl I64ConvolutionByConst1Coeff for FFT64Avx {
|
||||
#[inline(always)]
|
||||
fn i64_convolution_by_const_1coeff(k: usize, dst: &mut [i64; 8], a: &[i64], a_size: usize, b: &[i64]) {
|
||||
unsafe {
|
||||
i64_convolution_by_const_1coeff_avx(k, dst, a, a_size, b);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl I64ConvolutionByConst2Coeffs for FFT64Avx {
|
||||
#[inline(always)]
|
||||
fn i64_convolution_by_const_2coeffs(k: usize, dst: &mut [i64; 16], a: &[i64], a_size: usize, b: &[i64]) {
|
||||
unsafe {
|
||||
i64_convolution_by_real_const_2coeffs_avx(k, dst, a, a_size, b);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl I64Save1BlkContiguous for FFT64Avx {
|
||||
#[inline(always)]
|
||||
fn i64_save_1blk_contiguous(n: usize, offset: usize, rows: usize, blk: usize, dst: &mut [i64], src: &[i64]) {
|
||||
unsafe {
|
||||
i64_save_1blk_contiguous_avx(n, offset, rows, blk, dst, src);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl I64Extract1BlkContiguous for FFT64Avx {
|
||||
#[inline(always)]
|
||||
fn i64_extract_1blk_contiguous(n: usize, offset: usize, rows: usize, blk: usize, dst: &mut [i64], src: &[i64]) {
|
||||
unsafe {
|
||||
i64_extract_1blk_contiguous_avx(n, offset, rows, blk, dst, src);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -18,11 +18,7 @@ pub(crate) fn fft_avx2_fma(m: usize, omg: &[f64], data: &mut [f64]) {
|
||||
let (re, im) = data.split_at_mut(m);
|
||||
|
||||
if m == 16 {
|
||||
fft16_avx2_fma(
|
||||
as_arr_mut::<16, f64>(re),
|
||||
as_arr_mut::<16, f64>(im),
|
||||
as_arr::<16, f64>(omg),
|
||||
)
|
||||
fft16_avx2_fma(as_arr_mut::<16, f64>(re), as_arr_mut::<16, f64>(im), as_arr::<16, f64>(omg))
|
||||
} else if m <= 2048 {
|
||||
fft_bfs_16_avx2_fma(m, re, im, omg, 0);
|
||||
} else {
|
||||
@@ -70,12 +66,7 @@ fn fft_bfs_16_avx2_fma(m: usize, re: &mut [f64], im: &mut [f64], omg: &[f64], mu
|
||||
while mm > 16 {
|
||||
let h: usize = mm >> 2;
|
||||
for off in (0..m).step_by(mm) {
|
||||
bitwiddle_fft_avx2_fma(
|
||||
h,
|
||||
&mut re[off..],
|
||||
&mut im[off..],
|
||||
as_arr::<4, f64>(&omg[pos..]),
|
||||
);
|
||||
bitwiddle_fft_avx2_fma(h, &mut re[off..], &mut im[off..], as_arr::<4, f64>(&omg[pos..]));
|
||||
|
||||
pos += 4;
|
||||
}
|
||||
@@ -232,16 +223,10 @@ fn test_fft_avx2_fma() {
|
||||
|
||||
let mut values_0: Vec<f64> = vec![0f64; m << 1];
|
||||
let scale: f64 = 1.0f64 / m as f64;
|
||||
values_0
|
||||
.iter_mut()
|
||||
.enumerate()
|
||||
.for_each(|(i, x)| *x = (i + 1) as f64 * scale);
|
||||
values_0.iter_mut().enumerate().for_each(|(i, x)| *x = (i + 1) as f64 * scale);
|
||||
|
||||
let mut values_1: Vec<f64> = vec![0f64; m << 1];
|
||||
values_1
|
||||
.iter_mut()
|
||||
.zip(values_0.iter())
|
||||
.for_each(|(y, x)| *y = *x);
|
||||
values_1.iter_mut().zip(values_0.iter()).for_each(|(y, x)| *y = *x);
|
||||
|
||||
ReimFFTAvx::reim_dft_execute(&table, &mut values_0);
|
||||
ReimFFTRef::reim_dft_execute(&table, &mut values_1);
|
||||
@@ -250,14 +235,7 @@ fn test_fft_avx2_fma() {
|
||||
|
||||
for i in 0..m * 2 {
|
||||
let diff: f64 = (values_0[i] - values_1[i]).abs();
|
||||
assert!(
|
||||
diff <= max_diff,
|
||||
"{} -> {}-{} = {}",
|
||||
i,
|
||||
values_0[i],
|
||||
values_1[i],
|
||||
diff
|
||||
)
|
||||
assert!(diff <= max_diff, "{} -> {}-{} = {}", i, values_0[i], values_1[i], diff)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -17,11 +17,7 @@ pub(crate) fn ifft_avx2_fma(m: usize, omg: &[f64], data: &mut [f64]) {
|
||||
let (re, im) = data.split_at_mut(m);
|
||||
|
||||
if m == 16 {
|
||||
ifft16_avx2_fma(
|
||||
as_arr_mut::<16, f64>(re),
|
||||
as_arr_mut::<16, f64>(im),
|
||||
as_arr::<16, f64>(omg),
|
||||
)
|
||||
ifft16_avx2_fma(as_arr_mut::<16, f64>(re), as_arr_mut::<16, f64>(im), as_arr::<16, f64>(omg))
|
||||
} else if m <= 2048 {
|
||||
ifft_bfs_16_avx2_fma(m, re, im, omg, 0);
|
||||
} else {
|
||||
@@ -72,12 +68,7 @@ fn ifft_bfs_16_avx2_fma(m: usize, re: &mut [f64], im: &mut [f64], omg: &[f64], m
|
||||
while h < m_half {
|
||||
let mm: usize = h << 2;
|
||||
for off in (0..m).step_by(mm) {
|
||||
inv_bitwiddle_ifft_avx2_fma(
|
||||
h,
|
||||
&mut re[off..],
|
||||
&mut im[off..],
|
||||
as_arr::<4, f64>(&omg[pos..]),
|
||||
);
|
||||
inv_bitwiddle_ifft_avx2_fma(h, &mut re[off..], &mut im[off..], as_arr::<4, f64>(&omg[pos..]));
|
||||
pos += 4;
|
||||
}
|
||||
h = mm;
|
||||
@@ -225,16 +216,10 @@ fn test_ifft_avx2_fma() {
|
||||
|
||||
let mut values_0: Vec<f64> = vec![0f64; m << 1];
|
||||
let scale: f64 = 1.0f64 / m as f64;
|
||||
values_0
|
||||
.iter_mut()
|
||||
.enumerate()
|
||||
.for_each(|(i, x)| *x = (i + 1) as f64 * scale);
|
||||
values_0.iter_mut().enumerate().for_each(|(i, x)| *x = (i + 1) as f64 * scale);
|
||||
|
||||
let mut values_1: Vec<f64> = vec![0f64; m << 1];
|
||||
values_1
|
||||
.iter_mut()
|
||||
.zip(values_0.iter())
|
||||
.for_each(|(y, x)| *y = *x);
|
||||
values_1.iter_mut().zip(values_0.iter()).for_each(|(y, x)| *y = *x);
|
||||
|
||||
ReimIFFTAvx::reim_dft_execute(&table, &mut values_0);
|
||||
ReimIFFTRef::reim_dft_execute(&table, &mut values_1);
|
||||
@@ -243,14 +228,7 @@ fn test_ifft_avx2_fma() {
|
||||
|
||||
for i in 0..m * 2 {
|
||||
let diff: f64 = (values_0[i] - values_1[i]).abs();
|
||||
assert!(
|
||||
diff <= max_diff,
|
||||
"{} -> {}-{} = {}",
|
||||
i,
|
||||
values_0[i],
|
||||
values_1[i],
|
||||
diff
|
||||
)
|
||||
assert!(diff <= max_diff, "{} -> {}-{} = {}", i, values_0[i], values_1[i], diff)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -32,10 +32,7 @@ use rand_distr::num_traits::{Float, FloatConst};
|
||||
|
||||
use crate::reim::{fft_avx2_fma::fft_avx2_fma, ifft_avx2_fma::ifft_avx2_fma};
|
||||
|
||||
global_asm!(
|
||||
include_str!("fft16_avx2_fma.s"),
|
||||
include_str!("ifft16_avx2_fma.s")
|
||||
);
|
||||
global_asm!(include_str!("fft16_avx2_fma.s"), include_str!("ifft16_avx2_fma.s"));
|
||||
|
||||
#[inline(always)]
|
||||
pub(crate) fn as_arr<const SIZE: usize, R: Float + FloatConst>(x: &[R]) -> &[R; SIZE] {
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
/// # Safety
|
||||
/// Caller must ensure the CPU supports AVX2 (e.g., via `is_x86_feature_detected!("avx2")`).
|
||||
#[target_feature(enable = "avx")]
|
||||
pub fn reim4_extract_1blk_from_reim_avx(m: usize, rows: usize, blk: usize, dst: &mut [f64], src: &[f64]) {
|
||||
pub fn reim4_extract_1blk_from_reim_contiguous_avx(m: usize, rows: usize, blk: usize, dst: &mut [f64], src: &[f64]) {
|
||||
use core::arch::x86_64::{__m256d, _mm256_loadu_pd, _mm256_storeu_pd};
|
||||
|
||||
unsafe {
|
||||
@@ -20,6 +20,28 @@ pub fn reim4_extract_1blk_from_reim_avx(m: usize, rows: usize, blk: usize, dst:
|
||||
}
|
||||
}
|
||||
|
||||
/// # Safety
|
||||
/// Caller must ensure the CPU supports AVX2 (e.g., via `is_x86_feature_detected!("avx2")`).
|
||||
#[target_feature(enable = "avx")]
|
||||
pub fn reim4_save_1blk_to_reim_contiguous_avx(m: usize, rows: usize, blk: usize, dst: &mut [f64], src: &[f64]) {
|
||||
use core::arch::x86_64::{__m256d, _mm256_loadu_pd, _mm256_storeu_pd};
|
||||
|
||||
unsafe {
|
||||
let mut src_ptr: *const __m256d = src.as_ptr() as *const __m256d;
|
||||
let mut dst_ptr: *mut __m256d = dst.as_mut_ptr().add(blk << 2) as *mut __m256d; // dst + 4*blk
|
||||
|
||||
let step: usize = m >> 2;
|
||||
|
||||
// Each iteration copies 4 doubles and advances dst by m doubles; each row takes two iterations (re then im)
|
||||
for _ in 0..2 * rows {
|
||||
let v: __m256d = _mm256_loadu_pd(src_ptr as *const f64);
|
||||
_mm256_storeu_pd(dst_ptr as *mut f64, v);
|
||||
dst_ptr = dst_ptr.add(step);
|
||||
src_ptr = src_ptr.add(1);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// # Safety
|
||||
/// Caller must ensure the CPU supports AVX2 (e.g., via `is_x86_feature_detected!("avx2")`).
|
||||
#[target_feature(enable = "avx2,fma")]
|
||||
@@ -148,11 +170,7 @@ pub fn reim4_vec_mat2cols_product_avx(nrows: usize, dst: &mut [f64], u: &[f64],
|
||||
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
assert!(
|
||||
dst.len() >= 8,
|
||||
"dst must be at least 8 doubles but is {}",
|
||||
dst.len()
|
||||
);
|
||||
assert!(dst.len() >= 8, "dst must be at least 8 doubles but is {}", dst.len());
|
||||
assert!(
|
||||
u.len() >= nrows * 8,
|
||||
"u must be at least nrows={} * 8 doubles but is {}",
|
||||
@@ -185,16 +203,16 @@ pub fn reim4_vec_mat2cols_product_avx(nrows: usize, dst: &mut [f64], u: &[f64],
|
||||
let br: __m256d = _mm256_loadu_pd(v_ptr.add(8));
|
||||
let bi: __m256d = _mm256_loadu_pd(v_ptr.add(12));
|
||||
|
||||
// re1 = re1 - ui*ai; re2 = re2 - ui*bi;
|
||||
// re1 = ui*ai - re1; re2 = ui*bi - re2;
|
||||
re1 = _mm256_fmsub_pd(ui, ai, re1);
|
||||
re2 = _mm256_fmsub_pd(ui, bi, re2);
|
||||
// im1 = im1 + ur*ai; im2 = im2 + ur*bi;
|
||||
// im1 = ur*ai + im1; im2 = ur*bi + im2;
|
||||
im1 = _mm256_fmadd_pd(ur, ai, im1);
|
||||
im2 = _mm256_fmadd_pd(ur, bi, im2);
|
||||
// re1 = re1 - ur*ar; re2 = re2 - ur*br;
|
||||
// re1 = ur*ar - re1; re2 = ur*br - re2;
|
||||
re1 = _mm256_fmsub_pd(ur, ar, re1);
|
||||
re2 = _mm256_fmsub_pd(ur, br, re2);
|
||||
// im1 = im1 + ui*ar; im2 = im2 + ui*br;
|
||||
// im1 = ui*ar + im1; im2 = ui*br + im2;
|
||||
im1 = _mm256_fmadd_pd(ui, ar, im1);
|
||||
im2 = _mm256_fmadd_pd(ui, br, im2);
|
||||
|
||||
@@ -219,10 +237,7 @@ pub fn reim4_vec_mat2cols_2ndcol_product_avx(nrows: usize, dst: &mut [f64], u: &
|
||||
{
|
||||
assert_eq!(dst.len(), 16, "dst must have 16 doubles");
|
||||
assert!(u.len() >= nrows * 8, "u must be at least nrows * 8 doubles");
|
||||
assert!(
|
||||
v.len() >= nrows * 16,
|
||||
"v must be at least nrows * 16 doubles"
|
||||
);
|
||||
assert!(v.len() >= nrows * 16, "v must be at least nrows * 16 doubles");
|
||||
}
|
||||
|
||||
unsafe {
|
||||
@@ -239,13 +254,13 @@ pub fn reim4_vec_mat2cols_2ndcol_product_avx(nrows: usize, dst: &mut [f64], u: &
|
||||
let ar: __m256d = _mm256_loadu_pd(v_ptr);
|
||||
let ai: __m256d = _mm256_loadu_pd(v_ptr.add(4));
|
||||
|
||||
// re1 = re1 - ui*ai; re2 = re2 - ui*bi;
|
||||
// re1 = ui*ai - re1;
|
||||
re1 = _mm256_fmsub_pd(ui, ai, re1);
|
||||
// im1 = im1 + ur*ai; im2 = im2 + ur*bi;
|
||||
// im1 = im1 + ur*ai;
|
||||
im1 = _mm256_fmadd_pd(ur, ai, im1);
|
||||
// re1 = re1 - ur*ar; re2 = re2 - ur*br;
|
||||
// re1 = ur*ar - re1;
|
||||
re1 = _mm256_fmsub_pd(ur, ar, re1);
|
||||
// im1 = im1 + ui*ar; im2 = im2 + ui*br;
|
||||
// im1 = im1 + ui*ar;
|
||||
im1 = _mm256_fmadd_pd(ui, ar, im1);
|
||||
|
||||
u_ptr = u_ptr.add(8);
|
||||
@@ -256,3 +271,360 @@ pub fn reim4_vec_mat2cols_2ndcol_product_avx(nrows: usize, dst: &mut [f64], u: &
|
||||
_mm256_storeu_pd(dst.as_mut_ptr().add(4), im1);
|
||||
}
|
||||
}
|
||||
|
||||
/// # Safety
|
||||
/// Caller must ensure the CPU supports AVX2 and FMA (e.g. `is_x86_feature_detected!("avx2")`).
|
||||
#[target_feature(enable = "avx2", enable = "fma")]
|
||||
pub unsafe fn reim4_convolution_1coeff_avx(k: usize, dst: &mut [f64; 8], a: &[f64], a_size: usize, b: &[f64], b_size: usize) {
|
||||
use core::arch::x86_64::{__m256d, _mm256_fmadd_pd, _mm256_fmsub_pd, _mm256_loadu_pd, _mm256_setzero_pd, _mm256_storeu_pd};
|
||||
|
||||
unsafe {
|
||||
// Scalar guard — same semantics as reference implementation
|
||||
if k >= a_size + b_size {
|
||||
let zero: __m256d = _mm256_setzero_pd();
|
||||
let dst_ptr: *mut f64 = dst.as_mut_ptr();
|
||||
_mm256_storeu_pd(dst_ptr, zero);
|
||||
_mm256_storeu_pd(dst_ptr.add(4), zero);
|
||||
return;
|
||||
}
|
||||
|
||||
let j_min: usize = k.saturating_sub(a_size - 1);
|
||||
let j_max: usize = (k + 1).min(b_size);
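// Keep both k - j and j in range: j ∈ [max(0, k - (a_size - 1)), min(k + 1, b_size)).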
|
||||
|
||||
// acc_re = dst[0..4], acc_im = dst[4..8]
|
||||
let mut acc_re: __m256d = _mm256_setzero_pd();
|
||||
let mut acc_im: __m256d = _mm256_setzero_pd();
|
||||
|
||||
let mut a_ptr: *const f64 = a.as_ptr().add(8 * (k - j_min));
|
||||
let mut b_ptr: *const f64 = b.as_ptr().add(8 * j_min);
|
||||
|
||||
for _ in 0..j_max - j_min {
|
||||
// Load a[(k - j)]
|
||||
let ar: __m256d = _mm256_loadu_pd(a_ptr);
|
||||
let ai: __m256d = _mm256_loadu_pd(a_ptr.add(4));
|
||||
|
||||
// Load b[j]
|
||||
let br: __m256d = _mm256_loadu_pd(b_ptr);
|
||||
let bi: __m256d = _mm256_loadu_pd(b_ptr.add(4));
|
||||
|
||||
// acc_re = ai*bi - acc_re
|
||||
acc_re = _mm256_fmsub_pd(ai, bi, acc_re);
|
||||
// acc_im = ar*bi + acc_im
|
||||
acc_im = _mm256_fmadd_pd(ar, bi, acc_im);
|
||||
// acc_re = ar*br - acc_re
|
||||
acc_re = _mm256_fmsub_pd(ar, br, acc_re);
|
||||
// acc_im = acc_im + ai*br
|
||||
acc_im = _mm256_fmadd_pd(ai, br, acc_im);
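// Net effect per iteration: acc_re += ar*br - ai*bi and acc_im += ar*bi + ai*br,
// i.e. a complex multiply-accumulate of a[k - j] * b[j]; the two fmsub calls
// negate acc_re twice per iteration, so the accumulated sign is preserved.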
|
||||
|
||||
a_ptr = a_ptr.sub(8);
|
||||
b_ptr = b_ptr.add(8);
|
||||
}
|
||||
|
||||
// Store accumulators into dst
|
||||
_mm256_storeu_pd(dst.as_mut_ptr(), acc_re);
|
||||
_mm256_storeu_pd(dst.as_mut_ptr().add(4), acc_im);
|
||||
}
|
||||
}
|
||||
|
||||
/// # Safety
|
||||
/// Caller must ensure the CPU supports AVX2 and FMA (e.g. `is_x86_feature_detected!("avx2")`).
|
||||
#[target_feature(enable = "avx2", enable = "fma")]
|
||||
pub unsafe fn reim4_convolution_2coeffs_avx(k: usize, dst: &mut [f64; 16], a: &[f64], a_size: usize, b: &[f64], b_size: usize) {
|
||||
use core::arch::x86_64::{__m256d, _mm256_fmadd_pd, _mm256_fnmadd_pd, _mm256_loadu_pd, _mm256_setzero_pd, _mm256_storeu_pd};
|
||||
|
||||
debug_assert!(a.len() >= 8 * a_size);
|
||||
debug_assert!(b.len() >= 8 * b_size);
|
||||
|
||||
let k0: usize = k;
|
||||
let k1: usize = k + 1;
|
||||
let bound: usize = a_size + b_size;
|
||||
|
||||
// If k0 is out of range then so is k1 = k0 + 1, so both output coefficients are zero.
|
||||
if k0 >= bound {
|
||||
unsafe {
|
||||
let zero: __m256d = _mm256_setzero_pd();
|
||||
let dst_ptr: *mut f64 = dst.as_mut_ptr();
|
||||
_mm256_storeu_pd(dst_ptr, zero);
|
||||
_mm256_storeu_pd(dst_ptr.add(4), zero);
|
||||
_mm256_storeu_pd(dst_ptr.add(8), zero);
|
||||
_mm256_storeu_pd(dst_ptr.add(12), zero);
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
unsafe {
|
||||
let mut acc_re_k0: __m256d = _mm256_setzero_pd();
|
||||
let mut acc_im_k0: __m256d = _mm256_setzero_pd();
|
||||
let mut acc_re_k1: __m256d = _mm256_setzero_pd();
|
||||
let mut acc_im_k1: __m256d = _mm256_setzero_pd();
|
||||
|
||||
let j0_min: usize = (k0 + 1).saturating_sub(a_size);
|
||||
let j0_max: usize = (k0 + 1).min(b_size);
|
||||
|
||||
if k1 >= bound {
|
||||
let mut a_k0_ptr: *const f64 = a.as_ptr().add(8 * (k0 - j0_min));
|
||||
let mut b_ptr: *const f64 = b.as_ptr().add(8 * j0_min);
|
||||
|
||||
// Contributions to k0 only, j ∈ [j0_min, j0_max)
|
||||
for _ in 0..j0_max - j0_min {
|
||||
let ar: __m256d = _mm256_loadu_pd(a_k0_ptr);
|
||||
let ai: __m256d = _mm256_loadu_pd(a_k0_ptr.add(4));
|
||||
let br: __m256d = _mm256_loadu_pd(b_ptr);
|
||||
let bi: __m256d = _mm256_loadu_pd(b_ptr.add(4));
|
||||
|
||||
acc_re_k0 = _mm256_fmadd_pd(ar, br, acc_re_k0);
|
||||
acc_re_k0 = _mm256_fnmadd_pd(ai, bi, acc_re_k0);
|
||||
acc_im_k0 = _mm256_fmadd_pd(ar, bi, acc_im_k0);
|
||||
acc_im_k0 = _mm256_fmadd_pd(ai, br, acc_im_k0);
|
||||
|
||||
a_k0_ptr = a_k0_ptr.sub(8);
|
||||
b_ptr = b_ptr.add(8);
|
||||
}
|
||||
} else {
|
||||
let j1_min: usize = (k1 + 1).saturating_sub(a_size);
|
||||
let j1_max: usize = (k1 + 1).min(b_size);
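// Three-region split: [j0_min, j1_min) feeds k0 only, [j1_min, j0_max) feeds
// both coefficients, and [j0_max, j1_max) feeds k1 only.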
|
||||
|
||||
let mut a_k0_ptr: *const f64 = a.as_ptr().add(8 * (k0 - j0_min));
|
||||
let mut a_k1_ptr: *const f64 = a.as_ptr().add(8 * (k1 - j1_min));
|
||||
let mut b_ptr: *const f64 = b.as_ptr().add(8 * j0_min);
|
||||
|
||||
// Region 1: contributions to k0 only, j ∈ [j0_min, j1_min)
|
||||
for _ in 0..j1_min - j0_min {
|
||||
let ar: __m256d = _mm256_loadu_pd(a_k0_ptr);
|
||||
let ai: __m256d = _mm256_loadu_pd(a_k0_ptr.add(4));
|
||||
let br: __m256d = _mm256_loadu_pd(b_ptr);
|
||||
let bi: __m256d = _mm256_loadu_pd(b_ptr.add(4));
|
||||
|
||||
acc_re_k0 = _mm256_fmadd_pd(ar, br, acc_re_k0);
|
||||
acc_re_k0 = _mm256_fnmadd_pd(ai, bi, acc_re_k0);
|
||||
acc_im_k0 = _mm256_fmadd_pd(ar, bi, acc_im_k0);
|
||||
acc_im_k0 = _mm256_fmadd_pd(ai, br, acc_im_k0);
|
||||
|
||||
a_k0_ptr = a_k0_ptr.sub(8);
|
||||
b_ptr = b_ptr.add(8);
|
||||
}
|
||||
|
||||
// Region 2: overlap, contributions to both k0 and k1, j ∈ [j1_min, j0_max)
|
||||
// We can save one load on b.
|
||||
for _ in 0..j0_max - j1_min {
|
||||
let ar0: __m256d = _mm256_loadu_pd(a_k0_ptr);
|
||||
let ai0: __m256d = _mm256_loadu_pd(a_k0_ptr.add(4));
|
||||
let ar1: __m256d = _mm256_loadu_pd(a_k1_ptr);
|
||||
let ai1: __m256d = _mm256_loadu_pd(a_k1_ptr.add(4));
|
||||
let br: __m256d = _mm256_loadu_pd(b_ptr);
|
||||
let bi: __m256d = _mm256_loadu_pd(b_ptr.add(4));
|
||||
|
||||
// k0
|
||||
acc_re_k0 = _mm256_fmadd_pd(ar0, br, acc_re_k0);
|
||||
acc_re_k0 = _mm256_fnmadd_pd(ai0, bi, acc_re_k0);
|
||||
acc_im_k0 = _mm256_fmadd_pd(ar0, bi, acc_im_k0);
|
||||
acc_im_k0 = _mm256_fmadd_pd(ai0, br, acc_im_k0);
|
||||
|
||||
// k1
|
||||
acc_re_k1 = _mm256_fmadd_pd(ar1, br, acc_re_k1);
|
||||
acc_re_k1 = _mm256_fnmadd_pd(ai1, bi, acc_re_k1);
|
||||
acc_im_k1 = _mm256_fmadd_pd(ar1, bi, acc_im_k1);
|
||||
acc_im_k1 = _mm256_fmadd_pd(ai1, br, acc_im_k1);
|
||||
|
||||
a_k0_ptr = a_k0_ptr.sub(8);
|
||||
a_k1_ptr = a_k1_ptr.sub(8);
|
||||
b_ptr = b_ptr.add(8);
|
||||
}
|
||||
|
||||
// Region 3: contributions to k1 only, j ∈ [j0_max, j1_max)
|
||||
for _ in 0..j1_max - j0_max {
|
||||
let ar1: __m256d = _mm256_loadu_pd(a_k1_ptr);
|
||||
let ai1: __m256d = _mm256_loadu_pd(a_k1_ptr.add(4));
|
||||
let br: __m256d = _mm256_loadu_pd(b_ptr);
|
||||
let bi: __m256d = _mm256_loadu_pd(b_ptr.add(4));
|
||||
|
||||
acc_re_k1 = _mm256_fmadd_pd(ar1, br, acc_re_k1);
|
||||
acc_re_k1 = _mm256_fnmadd_pd(ai1, bi, acc_re_k1);
|
||||
acc_im_k1 = _mm256_fmadd_pd(ar1, bi, acc_im_k1);
|
||||
acc_im_k1 = _mm256_fmadd_pd(ai1, br, acc_im_k1);
|
||||
|
||||
a_k1_ptr = a_k1_ptr.sub(8);
|
||||
b_ptr = b_ptr.add(8);
|
||||
}
|
||||
}
|
||||
|
||||
// Store both coefficients
|
||||
let dst_ptr = dst.as_mut_ptr();
|
||||
_mm256_storeu_pd(dst_ptr, acc_re_k0);
|
||||
_mm256_storeu_pd(dst_ptr.add(4), acc_im_k0);
|
||||
_mm256_storeu_pd(dst_ptr.add(8), acc_re_k1);
|
||||
_mm256_storeu_pd(dst_ptr.add(12), acc_im_k1);
|
||||
}
|
||||
}
|
||||
|
||||
/// # Safety
|
||||
/// Caller must ensure the CPU supports AVX2 and FMA (e.g. `is_x86_feature_detected!("avx2")`).
|
||||
#[target_feature(enable = "avx2", enable = "fma")]
|
||||
pub unsafe fn reim4_convolution_by_real_const_1coeff_avx(k: usize, dst: &mut [f64; 8], a: &[f64], a_size: usize, b: &[f64]) {
|
||||
use core::arch::x86_64::{__m256d, _mm256_fmadd_pd, _mm256_loadu_pd, _mm256_set1_pd, _mm256_setzero_pd, _mm256_storeu_pd};
|
||||
|
||||
unsafe {
|
||||
let b_size: usize = b.len();
|
||||
|
||||
if k >= a_size + b_size {
|
||||
let zero: __m256d = _mm256_setzero_pd();
|
||||
let dst_ptr: *mut f64 = dst.as_mut_ptr();
|
||||
_mm256_storeu_pd(dst_ptr, zero);
|
||||
_mm256_storeu_pd(dst_ptr.add(4), zero);
|
||||
return;
|
||||
}
|
||||
|
||||
let j_min: usize = k.saturating_sub(a_size - 1);
|
||||
let j_max: usize = (k + 1).min(b_size);
|
||||
|
||||
// acc_re = dst[0..4], acc_im = dst[4..8]
|
||||
let mut acc_re: __m256d = _mm256_setzero_pd();
|
||||
let mut acc_im: __m256d = _mm256_setzero_pd();
|
||||
|
||||
let mut a_ptr: *const f64 = a.as_ptr().add(8 * (k - j_min));
|
||||
let mut b_ptr: *const f64 = b.as_ptr().add(j_min);
|
||||
|
||||
for _ in 0..j_max - j_min {
|
||||
// Load a[(k - j)]
|
||||
let ar: __m256d = _mm256_loadu_pd(a_ptr);
|
||||
let ai: __m256d = _mm256_loadu_pd(a_ptr.add(4));
|
||||
|
||||
// Load scalar b[j] and broadcast
|
||||
let br: __m256d = _mm256_set1_pd(*b_ptr);
|
||||
|
||||
// Complex * real:
|
||||
// re += ar * br
|
||||
// im += ai * br
|
||||
acc_re = _mm256_fmadd_pd(ar, br, acc_re);
|
||||
acc_im = _mm256_fmadd_pd(ai, br, acc_im);
|
||||
|
||||
a_ptr = a_ptr.sub(8);
|
||||
b_ptr = b_ptr.add(1);
|
||||
}
|
||||
|
||||
// Store accumulators into dst
|
||||
_mm256_storeu_pd(dst.as_mut_ptr(), acc_re);
|
||||
_mm256_storeu_pd(dst.as_mut_ptr().add(4), acc_im);
|
||||
}
|
||||
}
|
||||
|
||||
/// # Safety
|
||||
/// Caller must ensure the CPU supports AVX2 and FMA (e.g. `is_x86_feature_detected!("avx2")`).
|
||||
#[target_feature(enable = "avx2", enable = "fma")]
|
||||
pub unsafe fn reim4_convolution_by_real_const_2coeffs_avx(k: usize, dst: &mut [f64; 16], a: &[f64], a_size: usize, b: &[f64]) {
|
||||
use core::arch::x86_64::{__m256d, _mm256_fmadd_pd, _mm256_loadu_pd, _mm256_set1_pd, _mm256_setzero_pd, _mm256_storeu_pd};
|
||||
|
||||
let b_size: usize = b.len();
|
||||
|
||||
debug_assert!(a.len() >= 8 * a_size);
|
||||
|
||||
let k0: usize = k;
|
||||
let k1: usize = k + 1;
|
||||
let bound: usize = a_size + b_size;
|
||||
|
||||
// If k0 is out of range then so is k1 = k0 + 1, so both output coefficients are zero.
|
||||
if k0 >= bound {
|
||||
unsafe {
|
||||
let zero: __m256d = _mm256_setzero_pd();
|
||||
let dst_ptr: *mut f64 = dst.as_mut_ptr();
|
||||
_mm256_storeu_pd(dst_ptr, zero);
|
||||
_mm256_storeu_pd(dst_ptr.add(4), zero);
|
||||
_mm256_storeu_pd(dst_ptr.add(8), zero);
|
||||
_mm256_storeu_pd(dst_ptr.add(12), zero);
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
unsafe {
|
||||
let mut acc_re_k0: __m256d = _mm256_setzero_pd();
|
||||
let mut acc_im_k0: __m256d = _mm256_setzero_pd();
|
||||
let mut acc_re_k1: __m256d = _mm256_setzero_pd();
|
||||
let mut acc_im_k1: __m256d = _mm256_setzero_pd();
|
||||
|
||||
let j0_min: usize = (k0 + 1).saturating_sub(a_size);
|
||||
let j0_max: usize = (k0 + 1).min(b_size);
|
||||
|
||||
if k1 >= bound {
|
||||
let mut a_k0_ptr: *const f64 = a.as_ptr().add(8 * (k0 - j0_min));
|
||||
let mut b_ptr: *const f64 = b.as_ptr().add(j0_min);
|
||||
|
||||
// Contributions to k0 only
|
||||
for _ in 0..j0_max - j0_min {
|
||||
let ar: __m256d = _mm256_loadu_pd(a_k0_ptr);
|
||||
let ai: __m256d = _mm256_loadu_pd(a_k0_ptr.add(4));
|
||||
let br: __m256d = _mm256_set1_pd(*b_ptr);
|
||||
|
||||
// complex * real
|
||||
acc_re_k0 = _mm256_fmadd_pd(ar, br, acc_re_k0);
|
||||
acc_im_k0 = _mm256_fmadd_pd(ai, br, acc_im_k0);
|
||||
|
||||
a_k0_ptr = a_k0_ptr.sub(8);
|
||||
b_ptr = b_ptr.add(1);
|
||||
}
|
||||
} else {
|
||||
let j1_min: usize = (k1 + 1).saturating_sub(a_size);
|
||||
let j1_max: usize = (k1 + 1).min(b_size);
|
||||
|
||||
let mut a_k0_ptr: *const f64 = a.as_ptr().add(8 * (k0 - j0_min));
|
||||
let mut a_k1_ptr: *const f64 = a.as_ptr().add(8 * (k1 - j1_min));
|
||||
let mut b_ptr: *const f64 = b.as_ptr().add(j0_min);
|
||||
|
||||
// Region 1: k0 only, j ∈ [j0_min, j1_min)
|
||||
for _ in 0..j1_min - j0_min {
|
||||
let ar0: __m256d = _mm256_loadu_pd(a_k0_ptr);
|
||||
let ai0: __m256d = _mm256_loadu_pd(a_k0_ptr.add(4));
|
||||
let br: __m256d = _mm256_set1_pd(*b_ptr);
|
||||
|
||||
acc_re_k0 = _mm256_fmadd_pd(ar0, br, acc_re_k0);
|
||||
acc_im_k0 = _mm256_fmadd_pd(ai0, br, acc_im_k0);
|
||||
|
||||
a_k0_ptr = a_k0_ptr.sub(8);
|
||||
b_ptr = b_ptr.add(1);
|
||||
}
|
||||
|
||||
// Region 2: overlap, contributions to both k0 and k1, j ∈ [j1_min, j0_max)
|
||||
// Still “save one load on b”: we broadcast once and reuse.
|
||||
for _ in 0..j0_max - j1_min {
|
||||
let ar0: __m256d = _mm256_loadu_pd(a_k0_ptr);
|
||||
let ai0: __m256d = _mm256_loadu_pd(a_k0_ptr.add(4));
|
||||
let ar1: __m256d = _mm256_loadu_pd(a_k1_ptr);
|
||||
let ai1: __m256d = _mm256_loadu_pd(a_k1_ptr.add(4));
|
||||
let br: __m256d = _mm256_set1_pd(*b_ptr);
|
||||
|
||||
// k0
|
||||
acc_re_k0 = _mm256_fmadd_pd(ar0, br, acc_re_k0);
|
||||
acc_im_k0 = _mm256_fmadd_pd(ai0, br, acc_im_k0);
|
||||
|
||||
// k1
|
||||
acc_re_k1 = _mm256_fmadd_pd(ar1, br, acc_re_k1);
|
||||
acc_im_k1 = _mm256_fmadd_pd(ai1, br, acc_im_k1);
|
||||
|
||||
a_k0_ptr = a_k0_ptr.sub(8);
|
||||
a_k1_ptr = a_k1_ptr.sub(8);
|
||||
b_ptr = b_ptr.add(1);
|
||||
}
|
||||
|
||||
// Region 3: k1 only, j ∈ [j0_max, j1_max)
|
||||
for _ in 0..j1_max - j0_max {
|
||||
let ar1: __m256d = _mm256_loadu_pd(a_k1_ptr);
|
||||
let ai1: __m256d = _mm256_loadu_pd(a_k1_ptr.add(4));
|
||||
let br: __m256d = _mm256_set1_pd(*b_ptr);
|
||||
|
||||
acc_re_k1 = _mm256_fmadd_pd(ar1, br, acc_re_k1);
|
||||
acc_im_k1 = _mm256_fmadd_pd(ai1, br, acc_im_k1);
|
||||
|
||||
a_k1_ptr = a_k1_ptr.sub(8);
|
||||
b_ptr = b_ptr.add(1);
|
||||
}
|
||||
}
|
||||
|
||||
// Store both coefficients
|
||||
let dst_ptr = dst.as_mut_ptr();
|
||||
_mm256_storeu_pd(dst_ptr, acc_re_k0);
|
||||
_mm256_storeu_pd(dst_ptr.add(4), acc_im_k0);
|
||||
_mm256_storeu_pd(dst_ptr.add(8), acc_re_k1);
|
||||
_mm256_storeu_pd(dst_ptr.add(12), acc_im_k1);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,4 +1,8 @@
|
||||
use poulpy_hal::{api::ModuleNew, layouts::Module, test_suite::convolution::test_bivariate_tensoring};
|
||||
use poulpy_hal::{
|
||||
api::ModuleNew,
|
||||
layouts::Module,
|
||||
test_suite::convolution::{test_convolution, test_convolution_by_const, test_convolution_pairwise},
|
||||
};
|
||||
|
||||
use crate::FFT64Avx;
|
||||
|
||||
@@ -119,7 +123,19 @@ mod poulpy_cpu_avx {
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_convolution_fft64_avx() {
|
||||
let module: Module<FFT64Avx> = Module::<FFT64Avx>::new(64);
|
||||
test_bivariate_tensoring(&module);
|
||||
fn test_convolution_by_const_fft64_avx() {
|
||||
let module: Module<FFT64Avx> = Module::<FFT64Avx>::new(8);
|
||||
test_convolution_by_const(&module);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_convolution_fft64_avx() {
|
||||
let module: Module<FFT64Avx> = Module::<FFT64Avx>::new(8);
|
||||
test_convolution(&module);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_convolution_pairwise_fft64_avx() {
|
||||
let module: Module<FFT64Avx> = Module::<FFT64Avx>::new(8);
|
||||
test_convolution_pairwise(&module);
|
||||
}
|
||||
|
||||
@@ -53,11 +53,12 @@ where
|
||||
{
|
||||
fn vec_znx_normalize_impl<R, A>(
|
||||
module: &Module<Self>,
|
||||
res_base2k: usize,
|
||||
res: &mut R,
|
||||
res_base2k: usize,
|
||||
res_offset: i64,
|
||||
res_col: usize,
|
||||
a_base2k: usize,
|
||||
a: &A,
|
||||
a_base2k: usize,
|
||||
a_col: usize,
|
||||
scratch: &mut Scratch<Self>,
|
||||
) where
|
||||
@@ -65,7 +66,7 @@ where
|
||||
A: VecZnxToRef,
|
||||
{
|
||||
let (carry, _) = scratch.take_slice(module.vec_znx_normalize_tmp_bytes() / size_of::<i64>());
|
||||
vec_znx_normalize::<R, A, Self>(res_base2k, res, res_col, a_base2k, a, a_col, carry);
|
||||
vec_znx_normalize::<R, A, Self>(res, res_base2k, res_offset, res_col, a, a_base2k, a_col, carry);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -26,7 +26,7 @@ use poulpy_hal::{
|
||||
source::Source,
|
||||
};
|
||||
|
||||
unsafe impl VecZnxBigAllocBytesImpl<Self> for FFT64Avx {
|
||||
unsafe impl VecZnxBigAllocBytesImpl for FFT64Avx {
|
||||
fn vec_znx_big_bytes_of_impl(n: usize, cols: usize, size: usize) -> usize {
|
||||
Self::layout_big_word_count() * n * cols * size * size_of::<f64>()
|
||||
}
|
||||
@@ -280,11 +280,12 @@ where
|
||||
{
|
||||
fn vec_znx_big_normalize_impl<R, A>(
|
||||
module: &Module<Self>,
|
||||
res_basek: usize,
|
||||
res: &mut R,
|
||||
res_base2k: usize,
|
||||
res_offset: i64,
|
||||
res_col: usize,
|
||||
a_basek: usize,
|
||||
a: &A,
|
||||
a_base2k: usize,
|
||||
a_col: usize,
|
||||
scratch: &mut Scratch<Self>,
|
||||
) where
|
||||
@@ -292,7 +293,7 @@ where
|
||||
A: VecZnxBigToRef<Self>,
|
||||
{
|
||||
let (carry, _) = scratch.take_slice(module.vec_znx_big_normalize_tmp_bytes() / size_of::<i64>());
|
||||
vec_znx_big_normalize(res_basek, res, res_col, a_basek, a, a_col, carry);
|
||||
vec_znx_big_normalize(res, res_base2k, res_offset, res_col, a, a_base2k, a_col, carry);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -53,12 +53,8 @@ pub fn znx_automorphism_avx(p: i64, res: &mut [i64], a: &[i64]) {
|
||||
let mask_1n_vec: __m256i = _mm256_set1_epi64x(mask_1n as i64);
|
||||
|
||||
// Lane offsets [0, inv, 2*inv, 3*inv] (mod 2n)
|
||||
let lane_offsets: __m256i = _mm256_set_epi64x(
|
||||
((inv * 3) & mask_2n) as i64,
|
||||
((inv * 2) & mask_2n) as i64,
|
||||
inv as i64,
|
||||
0i64,
|
||||
);
|
||||
let lane_offsets: __m256i =
|
||||
_mm256_set_epi64x(((inv * 3) & mask_2n) as i64, ((inv * 2) & mask_2n) as i64, inv as i64, 0i64);
|
||||
|
||||
// t_base = (j * inv) mod 2n.
|
||||
let mut t_base: usize = 0;
|
||||
|
||||
@@ -82,14 +82,14 @@ pub fn znx_extract_digit_addmul_avx(base2k: usize, lsh: usize, res: &mut [i64],
|
||||
let mut ss: *mut __m256i = src.as_mut_ptr() as *mut __m256i;
|
||||
|
||||
// constants for digit/carry extraction
|
||||
let (mask, sign, basek_vec, top_mask) = normalize_consts_avx(base2k);
|
||||
let (mask, sign, base2k_vec, top_mask) = normalize_consts_avx(base2k);
|
||||
let lsh_v: __m256i = _mm256_set1_epi64x(lsh as i64);
|
||||
|
||||
for _ in 0..span {
|
||||
// load source & extract digit/carry
|
||||
let sv: __m256i = _mm256_loadu_si256(ss);
|
||||
let digit_256: __m256i = get_digit_avx(sv, mask, sign);
|
||||
let carry_256: __m256i = get_carry_avx(sv, digit_256, basek_vec, top_mask);
|
||||
let carry_256: __m256i = get_carry_avx(sv, digit_256, base2k_vec, top_mask);
|
||||
|
||||
// res += (digit << lsh)
|
||||
let rv: __m256i = _mm256_loadu_si256(rr);
|
||||
@@ -135,7 +135,7 @@ pub fn znx_normalize_digit_avx(base2k: usize, res: &mut [i64], src: &mut [i64])
|
||||
let mut ss: *mut __m256i = src.as_mut_ptr() as *mut __m256i;
|
||||
|
||||
// Constants for digit/carry extraction
|
||||
let (mask, sign, basek_vec, top_mask) = normalize_consts_avx(base2k);
|
||||
let (mask, sign, base2k_vec, top_mask) = normalize_consts_avx(base2k);
|
||||
|
||||
for _ in 0..span {
|
||||
// Load res lane
|
||||
@@ -143,7 +143,7 @@ pub fn znx_normalize_digit_avx(base2k: usize, res: &mut [i64], src: &mut [i64])
|
||||
|
||||
// Extract digit and carry from res
|
||||
let digit_256: __m256i = get_digit_avx(rv, mask, sign);
|
||||
let carry_256: __m256i = get_carry_avx(rv, digit_256, basek_vec, top_mask);
|
||||
let carry_256: __m256i = get_carry_avx(rv, digit_256, base2k_vec, top_mask);
|
||||
|
||||
// src += carry
|
||||
let sv: __m256i = _mm256_loadu_si256(ss);
|
||||
@@ -187,7 +187,7 @@ pub fn znx_normalize_first_step_carry_only_avx(base2k: usize, lsh: usize, x: &[i
|
||||
let mut xx: *const __m256i = x.as_ptr() as *const __m256i;
|
||||
let mut cc: *mut __m256i = carry.as_ptr() as *mut __m256i;
|
||||
|
||||
let (mask, sign, basek_vec, top_mask) = if lsh == 0 {
|
||||
let (mask, sign, base2k_vec, top_mask) = if lsh == 0 {
|
||||
normalize_consts_avx(base2k)
|
||||
} else {
|
||||
normalize_consts_avx(base2k - lsh)
|
||||
@@ -200,7 +200,7 @@ pub fn znx_normalize_first_step_carry_only_avx(base2k: usize, lsh: usize, x: &[i
|
||||
let digit_256: __m256i = get_digit_avx(xv, mask, sign);
|
||||
|
||||
// (x - digit) >> base2k
|
||||
let carry_256: __m256i = get_carry_avx(xv, digit_256, basek_vec, top_mask);
|
||||
let carry_256: __m256i = get_carry_avx(xv, digit_256, base2k_vec, top_mask);
|
||||
|
||||
_mm256_storeu_si256(cc, carry_256);
|
||||
|
||||
@@ -239,7 +239,7 @@ pub fn znx_normalize_first_step_inplace_avx(base2k: usize, lsh: usize, x: &mut [
|
||||
let mut cc: *mut __m256i = carry.as_ptr() as *mut __m256i;
|
||||
|
||||
if lsh == 0 {
|
||||
let (mask, sign, basek_vec, top_mask) = normalize_consts_avx(base2k);
|
||||
let (mask, sign, base2k_vec, top_mask) = normalize_consts_avx(base2k);
|
||||
|
||||
for _ in 0..span {
|
||||
let xv: __m256i = _mm256_loadu_si256(xx);
|
||||
@@ -248,7 +248,7 @@ pub fn znx_normalize_first_step_inplace_avx(base2k: usize, lsh: usize, x: &mut [
|
||||
                let digit_256: __m256i = get_digit_avx(xv, mask, sign);

                // (x - digit) >> base2k
-               let carry_256: __m256i = get_carry_avx(xv, digit_256, basek_vec, top_mask);
+               let carry_256: __m256i = get_carry_avx(xv, digit_256, base2k_vec, top_mask);

                _mm256_storeu_si256(xx, digit_256);
                _mm256_storeu_si256(cc, carry_256);
@@ -257,7 +257,7 @@ pub fn znx_normalize_first_step_inplace_avx(base2k: usize, lsh: usize, x: &mut [
                cc = cc.add(1);
            }
        } else {
-           let (mask, sign, basek_vec, top_mask) = normalize_consts_avx(base2k - lsh);
+           let (mask, sign, base2k_vec, top_mask) = normalize_consts_avx(base2k - lsh);

            let lsh_v: __m256i = _mm256_set1_epi64x(lsh as i64);

@@ -268,7 +268,7 @@ pub fn znx_normalize_first_step_inplace_avx(base2k: usize, lsh: usize, x: &mut [
                let digit_256: __m256i = get_digit_avx(xv, mask, sign);

                // (x - digit) >> base2k
-               let carry_256: __m256i = get_carry_avx(xv, digit_256, basek_vec, top_mask);
+               let carry_256: __m256i = get_carry_avx(xv, digit_256, base2k_vec, top_mask);

                _mm256_storeu_si256(xx, _mm256_sllv_epi64(digit_256, lsh_v));
                _mm256_storeu_si256(cc, carry_256);
@@ -311,7 +311,7 @@ pub fn znx_normalize_first_step_avx(base2k: usize, lsh: usize, x: &mut [i64], a:
        let mut cc: *mut __m256i = carry.as_ptr() as *mut __m256i;

        if lsh == 0 {
-           let (mask, sign, basek_vec, top_mask) = normalize_consts_avx(base2k);
+           let (mask, sign, base2k_vec, top_mask) = normalize_consts_avx(base2k);

            for _ in 0..span {
                let av: __m256i = _mm256_loadu_si256(aa);
@@ -320,7 +320,7 @@ pub fn znx_normalize_first_step_avx(base2k: usize, lsh: usize, x: &mut [i64], a:
                let digit_256: __m256i = get_digit_avx(av, mask, sign);

                // (x - digit) >> base2k
-               let carry_256: __m256i = get_carry_avx(av, digit_256, basek_vec, top_mask);
+               let carry_256: __m256i = get_carry_avx(av, digit_256, base2k_vec, top_mask);

                _mm256_storeu_si256(xx, digit_256);
                _mm256_storeu_si256(cc, carry_256);
@@ -332,7 +332,7 @@ pub fn znx_normalize_first_step_avx(base2k: usize, lsh: usize, x: &mut [i64], a:
        } else {
            use std::arch::x86_64::_mm256_set1_epi64x;

-           let (mask, sign, basek_vec, top_mask) = normalize_consts_avx(base2k - lsh);
+           let (mask, sign, base2k_vec, top_mask) = normalize_consts_avx(base2k - lsh);

            let lsh_v: __m256i = _mm256_set1_epi64x(lsh as i64);

@@ -343,7 +343,7 @@ pub fn znx_normalize_first_step_avx(base2k: usize, lsh: usize, x: &mut [i64], a:
                let digit_256: __m256i = get_digit_avx(av, mask, sign);

                // (x - digit) >> base2k
-               let carry_256: __m256i = get_carry_avx(av, digit_256, basek_vec, top_mask);
+               let carry_256: __m256i = get_carry_avx(av, digit_256, base2k_vec, top_mask);

                _mm256_storeu_si256(xx, _mm256_sllv_epi64(digit_256, lsh_v));
                _mm256_storeu_si256(cc, carry_256);
@@ -359,13 +359,7 @@ pub fn znx_normalize_first_step_avx(base2k: usize, lsh: usize, x: &mut [i64], a:
    if !x.len().is_multiple_of(4) {
        use poulpy_hal::reference::znx::znx_normalize_first_step_ref;

-       znx_normalize_first_step_ref(
-           base2k,
-           lsh,
-           &mut x[span << 2..],
-           &a[span << 2..],
-           &mut carry[span << 2..],
-       );
+       znx_normalize_first_step_ref(base2k, lsh, &mut x[span << 2..], &a[span << 2..], &mut carry[span << 2..]);
    }
}
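For readers following the rename above (`basek_vec` to `base2k_vec`): the first-step kernels split each coefficient into a base-2^k digit and a carry, shifting the stored digit left by `lsh` when an offset is requested. A minimal scalar sketch of one lane, assuming a centered digit convention; the function name and centering are illustrative, not code from this patch:

/// Scalar model of one lane of the "first step": split x into a centered
/// base-2^k digit and a carry (sketch only; assumes base2k > lsh).
fn first_step_scalar(base2k: usize, lsh: usize, x: i64) -> (i64, i64) {
    let k = base2k - lsh;                              // lsh == 0 keeps the full window
    let half = 1i64 << (k - 1);
    let mask = (1i64 << k) - 1;
    let digit = (x.wrapping_add(half) & mask) - half;  // digit in [-2^(k-1), 2^(k-1))
    let carry = x.wrapping_sub(digit) >> k;            // "(x - digit) >> base2k" from the code above
    (digit << lsh, carry)                              // the lsh > 0 path stores digit << lsh
}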
@@ -386,7 +380,7 @@ pub fn znx_normalize_middle_step_inplace_avx(base2k: usize, lsh: usize, x: &mut

    let span: usize = n >> 2;

-   let (mask, sign, basek_vec, top_mask) = normalize_consts_avx(base2k);
+   let (mask, sign, base2k_vec, top_mask) = normalize_consts_avx(base2k);

    unsafe {
        let mut xx: *mut __m256i = x.as_mut_ptr() as *mut __m256i;
@@ -398,11 +392,11 @@ pub fn znx_normalize_middle_step_inplace_avx(base2k: usize, lsh: usize, x: &mut
                let cv: __m256i = _mm256_loadu_si256(cc);

                let d0: __m256i = get_digit_avx(xv, mask, sign);
-               let c0: __m256i = get_carry_avx(xv, d0, basek_vec, top_mask);
+               let c0: __m256i = get_carry_avx(xv, d0, base2k_vec, top_mask);

                let s: __m256i = _mm256_add_epi64(d0, cv);
                let x1: __m256i = get_digit_avx(s, mask, sign);
-               let c1: __m256i = get_carry_avx(s, x1, basek_vec, top_mask);
+               let c1: __m256i = get_carry_avx(s, x1, base2k_vec, top_mask);
                let cout: __m256i = _mm256_add_epi64(c0, c1);

                _mm256_storeu_si256(xx, x1);
@@ -414,7 +408,7 @@ pub fn znx_normalize_middle_step_inplace_avx(base2k: usize, lsh: usize, x: &mut
        } else {
            use std::arch::x86_64::_mm256_set1_epi64x;

-           let (mask_lsh, sign_lsh, basek_vec_lsh, top_mask_lsh) = normalize_consts_avx(base2k - lsh);
+           let (mask_lsh, sign_lsh, base2k_vec_lsh, top_mask_lsh) = normalize_consts_avx(base2k - lsh);

            let lsh_v: __m256i = _mm256_set1_epi64x(lsh as i64);

@@ -423,13 +417,13 @@ pub fn znx_normalize_middle_step_inplace_avx(base2k: usize, lsh: usize, x: &mut
                let cv: __m256i = _mm256_loadu_si256(cc);

                let d0: __m256i = get_digit_avx(xv, mask_lsh, sign_lsh);
-               let c0: __m256i = get_carry_avx(xv, d0, basek_vec_lsh, top_mask_lsh);
+               let c0: __m256i = get_carry_avx(xv, d0, base2k_vec_lsh, top_mask_lsh);

                let d0_lsh: __m256i = _mm256_sllv_epi64(d0, lsh_v);

                let s: __m256i = _mm256_add_epi64(d0_lsh, cv);
                let x1: __m256i = get_digit_avx(s, mask, sign);
-               let c1: __m256i = get_carry_avx(s, x1, basek_vec, top_mask);
+               let c1: __m256i = get_carry_avx(s, x1, base2k_vec, top_mask);
                let cout: __m256i = _mm256_add_epi64(c0, c1);

                _mm256_storeu_si256(xx, x1);
@@ -465,7 +459,7 @@ pub fn znx_normalize_middle_step_carry_only_avx(base2k: usize, lsh: usize, x: &[

    let span: usize = n >> 2;

-   let (mask, sign, basek_vec, top_mask) = normalize_consts_avx(base2k);
+   let (mask, sign, base2k_vec, top_mask) = normalize_consts_avx(base2k);

    unsafe {
        let mut xx: *const __m256i = x.as_ptr() as *const __m256i;
@@ -477,11 +471,11 @@ pub fn znx_normalize_middle_step_carry_only_avx(base2k: usize, lsh: usize, x: &[
                let cv: __m256i = _mm256_loadu_si256(cc);

                let d0: __m256i = get_digit_avx(xv, mask, sign);
-               let c0: __m256i = get_carry_avx(xv, d0, basek_vec, top_mask);
+               let c0: __m256i = get_carry_avx(xv, d0, base2k_vec, top_mask);

                let s: __m256i = _mm256_add_epi64(d0, cv);
                let x1: __m256i = get_digit_avx(s, mask, sign);
-               let c1: __m256i = get_carry_avx(s, x1, basek_vec, top_mask);
+               let c1: __m256i = get_carry_avx(s, x1, base2k_vec, top_mask);
                let cout: __m256i = _mm256_add_epi64(c0, c1);

                _mm256_storeu_si256(cc, cout);
@@ -492,7 +486,7 @@ pub fn znx_normalize_middle_step_carry_only_avx(base2k: usize, lsh: usize, x: &[
        } else {
            use std::arch::x86_64::_mm256_set1_epi64x;

-           let (mask_lsh, sign_lsh, basek_vec_lsh, top_mask_lsh) = normalize_consts_avx(base2k - lsh);
+           let (mask_lsh, sign_lsh, base2k_vec_lsh, top_mask_lsh) = normalize_consts_avx(base2k - lsh);

            let lsh_v: __m256i = _mm256_set1_epi64x(lsh as i64);

@@ -501,13 +495,13 @@ pub fn znx_normalize_middle_step_carry_only_avx(base2k: usize, lsh: usize, x: &[
                let cv: __m256i = _mm256_loadu_si256(cc);

                let d0: __m256i = get_digit_avx(xv, mask_lsh, sign_lsh);
-               let c0: __m256i = get_carry_avx(xv, d0, basek_vec_lsh, top_mask_lsh);
+               let c0: __m256i = get_carry_avx(xv, d0, base2k_vec_lsh, top_mask_lsh);

                let d0_lsh: __m256i = _mm256_sllv_epi64(d0, lsh_v);

                let s: __m256i = _mm256_add_epi64(d0_lsh, cv);
                let x1: __m256i = get_digit_avx(s, mask, sign);
-               let c1: __m256i = get_carry_avx(s, x1, basek_vec, top_mask);
+               let c1: __m256i = get_carry_avx(s, x1, base2k_vec, top_mask);
                let cout: __m256i = _mm256_add_epi64(c0, c1);

                _mm256_storeu_si256(cc, cout);
@@ -543,7 +537,7 @@ pub fn znx_normalize_middle_step_avx(base2k: usize, lsh: usize, x: &mut [i64], a

    let span: usize = n >> 2;

-   let (mask, sign, basek_vec, top_mask) = normalize_consts_avx(base2k);
+   let (mask, sign, base2k_vec, top_mask) = normalize_consts_avx(base2k);

    unsafe {
        let mut xx: *mut __m256i = x.as_mut_ptr() as *mut __m256i;
@@ -556,11 +550,11 @@ pub fn znx_normalize_middle_step_avx(base2k: usize, lsh: usize, x: &mut [i64], a
                let cv: __m256i = _mm256_loadu_si256(cc);

                let d0: __m256i = get_digit_avx(av, mask, sign);
-               let c0: __m256i = get_carry_avx(av, d0, basek_vec, top_mask);
+               let c0: __m256i = get_carry_avx(av, d0, base2k_vec, top_mask);

                let s: __m256i = _mm256_add_epi64(d0, cv);
                let x1: __m256i = get_digit_avx(s, mask, sign);
-               let c1: __m256i = get_carry_avx(s, x1, basek_vec, top_mask);
+               let c1: __m256i = get_carry_avx(s, x1, base2k_vec, top_mask);
                let cout: __m256i = _mm256_add_epi64(c0, c1);

                _mm256_storeu_si256(xx, x1);
@@ -573,7 +567,7 @@ pub fn znx_normalize_middle_step_avx(base2k: usize, lsh: usize, x: &mut [i64], a
        } else {
            use std::arch::x86_64::_mm256_set1_epi64x;

-           let (mask_lsh, sign_lsh, basek_vec_lsh, top_mask_lsh) = normalize_consts_avx(base2k - lsh);
+           let (mask_lsh, sign_lsh, base2k_vec_lsh, top_mask_lsh) = normalize_consts_avx(base2k - lsh);

            let lsh_v: __m256i = _mm256_set1_epi64x(lsh as i64);

@@ -582,13 +576,13 @@ pub fn znx_normalize_middle_step_avx(base2k: usize, lsh: usize, x: &mut [i64], a
                let cv: __m256i = _mm256_loadu_si256(cc);

                let d0: __m256i = get_digit_avx(av, mask_lsh, sign_lsh);
-               let c0: __m256i = get_carry_avx(av, d0, basek_vec_lsh, top_mask_lsh);
+               let c0: __m256i = get_carry_avx(av, d0, base2k_vec_lsh, top_mask_lsh);

                let d0_lsh: __m256i = _mm256_sllv_epi64(d0, lsh_v);

                let s: __m256i = _mm256_add_epi64(d0_lsh, cv);
                let x1: __m256i = get_digit_avx(s, mask, sign);
-               let c1: __m256i = get_carry_avx(s, x1, basek_vec, top_mask);
+               let c1: __m256i = get_carry_avx(s, x1, base2k_vec, top_mask);
                let cout: __m256i = _mm256_add_epi64(c0, c1);

                _mm256_storeu_si256(xx, x1);
@@ -604,13 +598,7 @@ pub fn znx_normalize_middle_step_avx(base2k: usize, lsh: usize, x: &mut [i64], a
    if !x.len().is_multiple_of(4) {
        use poulpy_hal::reference::znx::znx_normalize_middle_step_ref;

-       znx_normalize_middle_step_ref(
-           base2k,
-           lsh,
-           &mut x[span << 2..],
-           &a[span << 2..],
-           &mut carry[span << 2..],
-       );
+       znx_normalize_middle_step_ref(base2k, lsh, &mut x[span << 2..], &a[span << 2..], &mut carry[span << 2..]);
    }
}
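The middle-step hunks follow the same digit/carry pattern but additionally fold in the carry arriving from the limb below (`cv`), producing a normalized limb `x1` and an outgoing carry `cout`. A scalar sketch of one lane of the `lsh == 0` path, mirroring the `d0`/`c0`/`s`/`x1`/`c1` names in the diff; the centered-digit convention is an assumption and this is not code from the patch:

/// Scalar model of one lane of the middle step (lsh == 0 path); sketch only.
fn middle_step_scalar(base2k: usize, x: i64, carry_in: i64) -> (i64, i64) {
    let half = 1i64 << (base2k - 1);
    let mask = (1i64 << base2k) - 1;
    let digit = |v: i64| (v.wrapping_add(half) & mask) - half;
    let carry = |v: i64, d: i64| v.wrapping_sub(d) >> base2k;

    let d0 = digit(x);
    let c0 = carry(x, d0);
    let s = d0.wrapping_add(carry_in);   // fold in the carry from the lower limb
    let x1 = digit(s);
    let c1 = carry(s, x1);
    (x1, c0.wrapping_add(c1))            // (normalized limb, carry propagated upward)
}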
@@ -753,13 +741,7 @@ pub fn znx_normalize_final_step_avx(base2k: usize, lsh: usize, x: &mut [i64], a:
    if !x.len().is_multiple_of(4) {
        use poulpy_hal::reference::znx::znx_normalize_final_step_ref;

-       znx_normalize_final_step_ref(
-           base2k,
-           lsh,
-           &mut x[span << 2..],
-           &a[span << 2..],
-           &mut carry[span << 2..],
-       );
+       znx_normalize_final_step_ref(base2k, lsh, &mut x[span << 2..], &a[span << 2..], &mut carry[span << 2..]);
    }
}
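The tail-handling changes above are purely cosmetic: the multi-line fallback call is collapsed onto one line. The shape they preserve is the usual SIMD-span-plus-scalar-remainder dispatch, sketched below with hypothetical `step_avx4` / `step_ref` stand-ins (the real kernels also take `base2k`, `lsh` and, for the non-inplace variants, an input slice):

// Hypothetical stand-ins for the real kernels; illustrative only.
fn step_avx4(x: &mut [i64], carry: &mut [i64]) { let _ = (x, carry); /* 4-lane AVX body */ }
fn step_ref(x: &mut [i64], carry: &mut [i64]) { let _ = (x, carry); /* scalar reference body */ }

fn normalize_dispatch(x: &mut [i64], carry: &mut [i64]) {
    let span = x.len() >> 2;                     // number of full 4-lane blocks
    for i in 0..span {
        let (lo, hi) = (i << 2, (i << 2) + 4);
        step_avx4(&mut x[lo..hi], &mut carry[lo..hi]);
    }
    if !x.len().is_multiple_of(4) {              // leftover lanes go through the reference path
        step_ref(&mut x[span << 2..], &mut carry[span << 2..]);
    }
}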
@@ -832,8 +814,8 @@ mod tests {
        unsafe {
            let x_256: __m256i = _mm256_loadu_si256(x.as_ptr() as *const __m256i);
            let d_256: __m256i = _mm256_loadu_si256(carry.as_ptr() as *const __m256i);
-           let (_, _, basek_vec, top_mask) = normalize_consts_avx(base2k);
-           let digit: __m256i = get_carry_avx(x_256, d_256, basek_vec, top_mask);
+           let (_, _, base2k_vec, top_mask) = normalize_consts_avx(base2k);
+           let digit: __m256i = get_carry_avx(x_256, d_256, base2k_vec, top_mask);
            _mm256_storeu_si256(y1.as_mut_ptr() as *mut __m256i, digit);
        }
        assert_eq!(y0, y1);
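Beyond the equality check above, a useful sanity property of any digit/carry pair is that it recomposes to the original coefficient in wrapping arithmetic. A sketch of such a check, for illustration only (not one of the patch's tests):

/// digit + (carry << base2k) should give back x under wrapping arithmetic.
fn recomposes(base2k: usize, x: i64, digit: i64, carry: i64) -> bool {
    digit.wrapping_add(carry.wrapping_shl(base2k as u32)) == x
}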