Support for bivariate convolution & normalization with offset (#126)

* Add bivariate convolution
* Add pairwise convolution + tests + benches
* Add take_cnv_pvec_[left/right] to Scratch & update CHANGELOG.md
* Add cross-base2k normalization with positive offset
* Fix clippy warnings & CI doctest AVX compile error
* Streamline bounds derivation for normalization
* Working cross-base2k normalization with pos/neg offset
* Update normalization API & tests
* Add glwe tensoring test
* Add relinearization + preliminary test
* Fix GGLWEToGGSW key infos
* Add (X,Y) convolution by const (1, Y) poly
* Faster normalization test + add bench for cnv_by_const
* Update changelog
Authored by Jean-Philippe Bossuat on 2025-12-21 16:56:42 +01:00, committed by GitHub
parent 76424d0ab5
commit 4e90e08a71
219 changed files with 6571 additions and 5041 deletions
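The new API, as exercised by the backend impls and tests below, follows a prepare-then-apply pattern. The sketch below is illustrative only: it assumes the `Convolution` trait exposes the `cnv_*` methods on `Module` without the `_impl` suffix used at the OEP layer, and that the operands and scratch space are already allocated.

// Hedged sketch (not part of the diff): one bivariate convolution of two
// VecZnx operands `a` and `b`, accumulating into `res_dft` (a VecZnxDft).
let mut a_prep: CnvPVecL<Vec<u8>, FFT64Avx> = CnvPVecL::alloc(module.n(), 1, a.size());
let mut b_prep: CnvPVecR<Vec<u8>, FFT64Avx> = CnvPVecR::alloc(module.n(), 1, b.size());
module.cnv_prepare_left(&mut a_prep, &a, scratch.borrow());
module.cnv_prepare_right(&mut b_prep, &b, scratch.borrow());
// Presumably accumulates res_dft[k] = sum_{i+j=k} a[i] * b[j] limb-wise in the
// DFT domain; column 0 of each operand, output offset 0.
module.cnv_apply_dft(&mut res_dft, 0, 0, &a_prep, 0, &b_prep, 0, scratch.borrow());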

View File

@@ -32,5 +32,5 @@ rustdoc-args = ["--cfg", "docsrs"]
[[bench]]
name = "vmp"
name = "convolution"
harness = false

View File

@@ -0,0 +1,116 @@
use criterion::{Criterion, criterion_group, criterion_main};
#[cfg(not(all(
feature = "enable-avx",
target_arch = "x86_64",
target_feature = "avx2",
target_feature = "fma"
)))]
fn bench_cnv_prepare_left_cpu_avx_fft64(_c: &mut Criterion) {
eprintln!("Skipping: AVX IFft benchmark requires x86_64 + AVX2 + FMA");
}
#[cfg(all(
feature = "enable-avx",
target_arch = "x86_64",
target_feature = "avx2",
target_feature = "fma"
))]
fn bench_cnv_prepare_left_cpu_avx_fft64(c: &mut Criterion) {
use poulpy_cpu_avx::FFT64Avx;
poulpy_hal::bench_suite::convolution::bench_cnv_prepare_left::<FFT64Avx>(c, "cpu_avx::fft64");
}
#[cfg(not(all(
feature = "enable-avx",
target_arch = "x86_64",
target_feature = "avx2",
target_feature = "fma"
)))]
fn bench_cnv_prepare_right_cpu_avx_fft64(_c: &mut Criterion) {
eprintln!("Skipping: AVX IFft benchmark requires x86_64 + AVX2 + FMA");
}
#[cfg(all(
feature = "enable-avx",
target_arch = "x86_64",
target_feature = "avx2",
target_feature = "fma"
))]
fn bench_cnv_prepare_right_cpu_avx_fft64(c: &mut Criterion) {
use poulpy_cpu_avx::FFT64Avx;
poulpy_hal::bench_suite::convolution::bench_cnv_prepare_right::<FFT64Avx>(c, "cpu_avx::fft64");
}
#[cfg(not(all(
feature = "enable-avx",
target_arch = "x86_64",
target_feature = "avx2",
target_feature = "fma"
)))]
fn bench_cnv_apply_dft_cpu_avx_fft64(_c: &mut Criterion) {
eprintln!("Skipping: AVX IFft benchmark requires x86_64 + AVX2 + FMA");
}
#[cfg(all(
feature = "enable-avx",
target_arch = "x86_64",
target_feature = "avx2",
target_feature = "fma"
))]
fn bench_cnv_apply_dft_cpu_avx_fft64(c: &mut Criterion) {
use poulpy_cpu_avx::FFT64Avx;
poulpy_hal::bench_suite::convolution::bench_cnv_apply_dft::<FFT64Avx>(c, "cpu_avx::fft64");
}
#[cfg(not(all(
feature = "enable-avx",
target_arch = "x86_64",
target_feature = "avx2",
target_feature = "fma"
)))]
fn bench_cnv_pairwise_apply_dft_cpu_avx_fft64(_c: &mut Criterion) {
eprintln!("Skipping: AVX IFft benchmark requires x86_64 + AVX2 + FMA");
}
#[cfg(all(
feature = "enable-avx",
target_arch = "x86_64",
target_feature = "avx2",
target_feature = "fma"
))]
fn bench_cnv_pairwise_apply_dft_cpu_avx_fft64(c: &mut Criterion) {
use poulpy_cpu_avx::FFT64Avx;
poulpy_hal::bench_suite::convolution::bench_cnv_pairwise_apply_dft::<FFT64Avx>(c, "cpu_avx::fft64");
}
#[cfg(not(all(
feature = "enable-avx",
target_arch = "x86_64",
target_feature = "avx2",
target_feature = "fma"
)))]
fn bench_cnv_by_const_apply_cpu_avx_fft64(_c: &mut Criterion) {
eprintln!("Skipping: AVX IFft benchmark requires x86_64 + AVX2 + FMA");
}
#[cfg(all(
feature = "enable-avx",
target_arch = "x86_64",
target_feature = "avx2",
target_feature = "fma"
))]
fn bench_cnv_by_const_apply_cpu_avx_fft64(c: &mut Criterion) {
use poulpy_cpu_avx::FFT64Avx;
poulpy_hal::bench_suite::convolution::bench_cnv_by_const_apply::<FFT64Avx>(c, "cpu_avx::fft64");
}
criterion_group!(
benches,
bench_cnv_prepare_left_cpu_avx_fft64,
bench_cnv_prepare_right_cpu_avx_fft64,
bench_cnv_apply_dft_cpu_avx_fft64,
bench_cnv_pairwise_apply_dft_cpu_avx_fft64,
bench_cnv_by_const_apply_cpu_avx_fft64,
);
criterion_main!(benches);

View File

@@ -1,11 +1,21 @@
use criterion::{Criterion, criterion_group, criterion_main};
#[cfg(not(all(feature = "enable-avx", target_arch = "x86_64", target_feature = "avx2", target_feature = "fma")))]
#[cfg(not(all(
feature = "enable-avx",
target_arch = "x86_64",
target_feature = "avx2",
target_feature = "fma"
)))]
fn bench_ifft_avx2_fma(_c: &mut Criterion) {
eprintln!("Skipping: AVX IFft benchmark requires x86_64 + AVX2 + FMA");
}
#[cfg(all(feature = "enable-avx", target_arch = "x86_64", target_feature = "avx2", target_feature = "fma"))]
#[cfg(all(
feature = "enable-avx",
target_arch = "x86_64",
target_feature = "avx2",
target_feature = "fma"
))]
pub fn bench_ifft_avx2_fma(c: &mut Criterion) {
use criterion::BenchmarkId;
use poulpy_cpu_avx::ReimIFFTAvx;
@@ -21,10 +31,7 @@ pub fn bench_ifft_avx2_fma(c: &mut Criterion) {
let mut values: Vec<f64> = vec![0f64; m << 1];
let scale = 1.0f64 / (2 * m) as f64;
values
.iter_mut()
.enumerate()
.for_each(|(i, x)| *x = (i + 1) as f64 * scale);
values.iter_mut().enumerate().for_each(|(i, x)| *x = (i + 1) as f64 * scale);
let table: ReimIFFTTable<f64> = ReimIFFTTable::<f64>::new(m);
move || {
@@ -47,12 +54,22 @@ pub fn bench_ifft_avx2_fma(c: &mut Criterion) {
group.finish();
}
#[cfg(not(all(feature = "enable-avx", target_arch = "x86_64", target_feature = "avx2", target_feature = "fma")))]
#[cfg(not(all(
feature = "enable-avx",
target_arch = "x86_64",
target_feature = "avx2",
target_feature = "fma"
)))]
fn bench_fft_avx2_fma(_c: &mut Criterion) {
eprintln!("Skipping: AVX FFT benchmark requires x86_64 + AVX2 + FMA");
}
#[cfg(all(feature = "enable-avx", target_arch = "x86_64", target_feature = "avx2", target_feature = "fma"))]
#[cfg(all(
feature = "enable-avx",
target_arch = "x86_64",
target_feature = "avx2",
target_feature = "fma"
))]
pub fn bench_fft_avx2_fma(c: &mut Criterion) {
use criterion::BenchmarkId;
use poulpy_cpu_avx::ReimFFTAvx;
@@ -68,10 +85,7 @@ pub fn bench_fft_avx2_fma(c: &mut Criterion) {
let mut values: Vec<f64> = vec![0f64; m << 1];
let scale = 1.0f64 / (2 * m) as f64;
values
.iter_mut()
.enumerate()
.for_each(|(i, x)| *x = (i + 1) as f64 * scale);
values.iter_mut().enumerate().for_each(|(i, x)| *x = (i + 1) as f64 * scale);
let table: ReimFFTTable<f64> = ReimFFTTable::<f64>::new(m);
move || {

View File

@@ -1,33 +1,63 @@
use criterion::{Criterion, criterion_group, criterion_main};
#[cfg(not(all(feature = "enable-avx", target_arch = "x86_64", target_feature = "avx2", target_feature = "fma")))]
#[cfg(not(all(
feature = "enable-avx",
target_arch = "x86_64",
target_feature = "avx2",
target_feature = "fma"
)))]
fn bench_vec_znx_add_cpu_avx_fft64(_c: &mut Criterion) {
eprintln!("Skipping: AVX IFft benchmark requires x86_64 + AVX2 + FMA");
}
#[cfg(all(feature = "enable-avx", target_arch = "x86_64", target_feature = "avx2", target_feature = "fma"))]
#[cfg(all(
feature = "enable-avx",
target_arch = "x86_64",
target_feature = "avx2",
target_feature = "fma"
))]
fn bench_vec_znx_add_cpu_avx_fft64(c: &mut Criterion) {
use poulpy_cpu_avx::FFT64Avx;
poulpy_hal::reference::vec_znx::bench_vec_znx_add::<FFT64Avx>(c, "FFT64Avx");
}
#[cfg(not(all(feature = "enable-avx", target_arch = "x86_64", target_feature = "avx2", target_feature = "fma")))]
#[cfg(not(all(
feature = "enable-avx",
target_arch = "x86_64",
target_feature = "avx2",
target_feature = "fma"
)))]
fn bench_vec_znx_normalize_inplace_cpu_avx_fft64(_c: &mut Criterion) {
eprintln!("Skipping: AVX IFft benchmark requires x86_64 + AVX2 + FMA");
}
#[cfg(all(feature = "enable-avx", target_arch = "x86_64", target_feature = "avx2", target_feature = "fma"))]
#[cfg(all(
feature = "enable-avx",
target_arch = "x86_64",
target_feature = "avx2",
target_feature = "fma"
))]
fn bench_vec_znx_normalize_inplace_cpu_avx_fft64(c: &mut Criterion) {
use poulpy_cpu_avx::FFT64Avx;
poulpy_hal::reference::vec_znx::bench_vec_znx_normalize_inplace::<FFT64Avx>(c, "FFT64Avx");
}
#[cfg(not(all(feature = "enable-avx", target_arch = "x86_64", target_feature = "avx2", target_feature = "fma")))]
#[cfg(not(all(
feature = "enable-avx",
target_arch = "x86_64",
target_feature = "avx2",
target_feature = "fma"
)))]
fn bench_vec_znx_automorphism_cpu_avx_fft64(_c: &mut Criterion) {
eprintln!("Skipping: AVX IFft benchmark requires x86_64 + AVX2 + FMA");
}
#[cfg(all(feature = "enable-avx", target_arch = "x86_64", target_feature = "avx2", target_feature = "fma"))]
#[cfg(all(
feature = "enable-avx",
target_arch = "x86_64",
target_feature = "avx2",
target_feature = "fma"
))]
fn bench_vec_znx_automorphism_cpu_avx_fft64(c: &mut Criterion) {
use poulpy_cpu_avx::FFT64Avx;
poulpy_hal::reference::vec_znx::bench_vec_znx_automorphism::<FFT64Avx>(c, "FFT64Avx");

View File

@@ -1,11 +1,21 @@
use criterion::{Criterion, criterion_group, criterion_main};
#[cfg(not(all(feature = "enable-avx", target_arch = "x86_64", target_feature = "avx2", target_feature = "fma")))]
#[cfg(not(all(
feature = "enable-avx",
target_arch = "x86_64",
target_feature = "avx2",
target_feature = "fma"
)))]
fn bench_vmp_apply_dft_to_dft_cpu_avx_fft64(_c: &mut Criterion) {
eprintln!("Skipping: AVX IFft benchmark requires x86_64 + AVX2 + FMA");
}
#[cfg(all(feature = "enable-avx", target_arch = "x86_64", target_feature = "avx2", target_feature = "fma"))]
#[cfg(all(
feature = "enable-avx",
target_arch = "x86_64",
target_feature = "avx2",
target_feature = "fma"
))]
fn bench_vmp_apply_dft_to_dft_cpu_avx_fft64(c: &mut Criterion) {
use poulpy_cpu_avx::FFT64Avx;
poulpy_hal::bench_suite::vmp::bench_vmp_apply_dft_to_dft::<FFT64Avx>(c, "FFT64Avx");

View File

@@ -1,8 +1,18 @@
use itertools::izip;
#[cfg(all(feature = "enable-avx", target_arch = "x86_64", target_feature = "avx2", target_feature = "fma"))]
#[cfg(all(
feature = "enable-avx",
target_arch = "x86_64",
target_feature = "avx2",
target_feature = "fma"
))]
use poulpy_cpu_avx::FFT64Avx as BackendImpl;
#[cfg(not(all(feature = "enable-avx", target_arch = "x86_64", target_feature = "avx2", target_feature = "fma")))]
#[cfg(not(all(
feature = "enable-avx",
target_arch = "x86_64",
target_feature = "avx2",
target_feature = "fma"
)))]
use poulpy_cpu_ref::FFT64Ref as BackendImpl;
use poulpy_hal::{
@@ -73,8 +83,7 @@ fn main() {
msg_size, // Number of small polynomials
);
let mut want: Vec<i64> = vec![0; n];
want.iter_mut()
.for_each(|x| *x = source.next_u64n(16, 15) as i64);
want.iter_mut().for_each(|x| *x = source.next_u64n(16, 15) as i64);
m.encode_vec_i64(base2k, 0, log_scale, &want);
module.vec_znx_normalize_inplace(base2k, &mut m, 0, scratch.borrow());
@@ -89,11 +98,12 @@ fn main() {
// Normalizes back to VecZnx
// ct[0] <- m - BIG(c1 * s)
module.vec_znx_big_normalize(
base2k,
&mut ct,
0, // Selects the first column of ct (ct[0])
base2k,
0,
0, // Selects the first column of ct (ct[0])
&buf_big,
base2k,
0, // Selects the first column of buf_big
scratch.borrow(),
);
@@ -131,15 +141,13 @@ fn main() {
// m + e <- BIG(ct[1] * s + ct[0])
let mut res = VecZnx::alloc(module.n(), 1, ct_size);
module.vec_znx_big_normalize(base2k, &mut res, 0, base2k, &buf_big, 0, scratch.borrow());
module.vec_znx_big_normalize(&mut res, base2k, 0, 0, &buf_big, base2k, 0, scratch.borrow());
// have = m * 2^{log_scale} + e
let mut have: Vec<i64> = vec![i64::default(); n];
res.decode_vec_i64(base2k, 0, ct_size * base2k, &mut have);
let scale: f64 = (1 << (res.size() * base2k - log_scale)) as f64;
izip!(want.iter(), have.iter())
.enumerate()
.for_each(|(i, (a, b))| {
println!("{}: {} {}", i, a, (*b as f64) / scale);
});
izip!(want.iter(), have.iter()).enumerate().for_each(|(i, (a, b))| {
println!("{}: {} {}", i, a, (*b as f64) / scale);
});
}

View File

@@ -0,0 +1,401 @@
use poulpy_hal::{
api::{Convolution, ModuleN, ScratchTakeBasic, TakeSlice, VecZnxDftApply, VecZnxDftBytesOf},
layouts::{
Backend, CnvPVecL, CnvPVecLToMut, CnvPVecLToRef, CnvPVecR, CnvPVecRToMut, CnvPVecRToRef, Module, Scratch, VecZnx,
VecZnxBig, VecZnxBigToMut, VecZnxDft, VecZnxDftToMut, VecZnxToRef, ZnxInfos,
},
oep::{CnvPVecBytesOfImpl, CnvPVecLAllocImpl, ConvolutionImpl},
reference::fft64::convolution::{
convolution_apply_dft, convolution_apply_dft_tmp_bytes, convolution_by_const_apply, convolution_by_const_apply_tmp_bytes,
convolution_pairwise_apply_dft, convolution_pairwise_apply_dft_tmp_bytes, convolution_prepare_left,
convolution_prepare_right,
},
};
use crate::{FFT64Avx, module::FFT64ModuleHandle};
unsafe impl CnvPVecLAllocImpl<Self> for FFT64Avx {
fn cnv_pvec_left_alloc_impl(n: usize, cols: usize, size: usize) -> CnvPVecL<Vec<u8>, Self> {
CnvPVecL::alloc(n, cols, size)
}
fn cnv_pvec_right_alloc_impl(n: usize, cols: usize, size: usize) -> CnvPVecR<Vec<u8>, Self> {
CnvPVecR::alloc(n, cols, size)
}
}
unsafe impl CnvPVecBytesOfImpl for FFT64Avx {
fn bytes_of_cnv_pvec_left_impl(n: usize, cols: usize, size: usize) -> usize {
Self::layout_prep_word_count() * n * cols * size * size_of::<<FFT64Avx as Backend>::ScalarPrep>()
}
fn bytes_of_cnv_pvec_right_impl(n: usize, cols: usize, size: usize) -> usize {
Self::layout_prep_word_count() * n * cols * size * size_of::<<FFT64Avx as Backend>::ScalarPrep>()
}
}
unsafe impl ConvolutionImpl<Self> for FFT64Avx
where
Module<Self>: ModuleN + VecZnxDftBytesOf + VecZnxDftApply<Self>,
{
fn cnv_prepare_left_tmp_bytes_impl(module: &Module<Self>, res_size: usize, a_size: usize) -> usize {
module.bytes_of_vec_znx_dft(1, res_size.min(a_size))
}
fn cnv_prepare_left_impl<R, A>(module: &Module<Self>, res: &mut R, a: &A, scratch: &mut Scratch<Self>)
where
R: CnvPVecLToMut<Self>,
A: VecZnxToRef,
{
let res: &mut CnvPVecL<&mut [u8], FFT64Avx> = &mut res.to_mut();
let a: &VecZnx<&[u8]> = &a.to_ref();
let (mut tmp, _) = scratch.take_vec_znx_dft(module, 1, res.size().min(a.size()));
convolution_prepare_left(module.get_fft_table(), res, a, &mut tmp);
}
fn cnv_prepare_right_tmp_bytes_impl(module: &Module<Self>, res_size: usize, a_size: usize) -> usize {
module.bytes_of_vec_znx_dft(1, res_size.min(a_size))
}
fn cnv_prepare_right_impl<R, A>(module: &Module<Self>, res: &mut R, a: &A, scratch: &mut Scratch<Self>)
where
R: CnvPVecRToMut<Self>,
A: VecZnxToRef,
{
let res: &mut CnvPVecR<&mut [u8], FFT64Avx> = &mut res.to_mut();
let a: &VecZnx<&[u8]> = &a.to_ref();
let (mut tmp, _) = scratch.take_vec_znx_dft(module, 1, res.size().min(a.size()));
convolution_prepare_right(module.get_fft_table(), res, a, &mut tmp);
}
fn cnv_apply_dft_tmp_bytes_impl(
_module: &Module<Self>,
res_size: usize,
_res_offset: usize,
a_size: usize,
b_size: usize,
) -> usize {
convolution_apply_dft_tmp_bytes(res_size, a_size, b_size)
}
fn cnv_by_const_apply_tmp_bytes_impl(
_module: &Module<Self>,
res_size: usize,
_res_offset: usize,
a_size: usize,
b_size: usize,
) -> usize {
convolution_by_const_apply_tmp_bytes(res_size, a_size, b_size)
}
fn cnv_by_const_apply_impl<R, A>(
module: &Module<Self>,
res: &mut R,
res_offset: usize,
res_col: usize,
a: &A,
a_col: usize,
b: &[i64],
scratch: &mut Scratch<Self>,
) where
R: VecZnxBigToMut<Self>,
A: VecZnxToRef,
{
let res: &mut VecZnxBig<&mut [u8], Self> = &mut res.to_mut();
let a: &VecZnx<&[u8]> = &a.to_ref();
let (tmp, _) =
scratch.take_slice(module.cnv_by_const_apply_tmp_bytes(res.size(), res_offset, a.size(), b.len()) / size_of::<i64>());
convolution_by_const_apply(res, res_offset, res_col, a, a_col, b, tmp);
}
fn cnv_apply_dft_impl<R, A, B>(
module: &Module<Self>,
res: &mut R,
res_offset: usize,
res_col: usize,
a: &A,
a_col: usize,
b: &B,
b_col: usize,
scratch: &mut Scratch<Self>,
) where
R: VecZnxDftToMut<Self>,
A: CnvPVecLToRef<Self>,
B: CnvPVecRToRef<Self>,
{
let res: &mut VecZnxDft<&mut [u8], FFT64Avx> = &mut res.to_mut();
let a: &CnvPVecL<&[u8], FFT64Avx> = &a.to_ref();
let b: &CnvPVecR<&[u8], FFT64Avx> = &b.to_ref();
let (tmp, _) =
scratch.take_slice(module.cnv_apply_dft_tmp_bytes(res.size(), res_offset, a.size(), b.size()) / size_of::<f64>());
convolution_apply_dft(res, res_offset, res_col, a, a_col, b, b_col, tmp);
}
fn cnv_pairwise_apply_dft_tmp_bytes(
_module: &Module<Self>,
res_size: usize,
_res_offset: usize,
a_size: usize,
b_size: usize,
) -> usize {
convolution_pairwise_apply_dft_tmp_bytes(res_size, a_size, b_size)
}
fn cnv_pairwise_apply_dft_impl<R, A, B>(
module: &Module<Self>,
res: &mut R,
res_offset: usize,
res_col: usize,
a: &A,
b: &B,
i: usize,
j: usize,
scratch: &mut Scratch<Self>,
) where
R: VecZnxDftToMut<Self>,
A: CnvPVecLToRef<Self>,
B: CnvPVecRToRef<Self>,
{
let res: &mut VecZnxDft<&mut [u8], FFT64Avx> = &mut res.to_mut();
let a: &CnvPVecL<&[u8], FFT64Avx> = &a.to_ref();
let b: &CnvPVecR<&[u8], FFT64Avx> = &b.to_ref();
let (tmp, _) = scratch
.take_slice(module.cnv_pairwise_apply_dft_tmp_bytes(res.size(), res_offset, a.size(), b.size()) / size_of::<f64>());
convolution_pairwise_apply_dft(res, res_offset, res_col, a, b, i, j, tmp);
}
}
/// # Safety
/// Caller must ensure the CPU supports AVX2.
/// Assumes all inputs fit in i32 (so i32×i32→i64 is exact).
#[target_feature(enable = "avx2")]
pub unsafe fn i64_convolution_by_const_1coeff_avx(k: usize, dst: &mut [i64; 8], a: &[i64], a_size: usize, b: &[i64]) {
use core::arch::x86_64::{
__m256i, _mm256_add_epi64, _mm256_loadu_si256, _mm256_mul_epi32, _mm256_set1_epi32, _mm256_setzero_si256,
_mm256_storeu_si256,
};
dst.fill(0);
let b_size = b.len();
if k >= a_size + b_size {
return;
}
let j_min = k.saturating_sub(a_size - 1);
let j_max = (k + 1).min(b_size);
unsafe {
// Two accumulators = 8 outputs total
let mut acc_lo: __m256i = _mm256_setzero_si256(); // dst[0..4)
let mut acc_hi: __m256i = _mm256_setzero_si256(); // dst[4..8)
let mut a_ptr: *const i64 = a.as_ptr().add(8 * (k - j_min));
let mut b_ptr: *const i64 = b.as_ptr().add(j_min);
for _ in 0..(j_max - j_min) {
// Broadcast scalar b[j] as i32
let br: __m256i = _mm256_set1_epi32(*b_ptr as i32);
// ---- lower half: a[0..4) ----
let a_lo: __m256i = _mm256_loadu_si256(a_ptr as *const __m256i);
let prod_lo: __m256i = _mm256_mul_epi32(a_lo, br);
acc_lo = _mm256_add_epi64(acc_lo, prod_lo);
// ---- upper half: a[4..8) ----
let a_hi: __m256i = _mm256_loadu_si256(a_ptr.add(4) as *const __m256i);
let prod_hi: __m256i = _mm256_mul_epi32(a_hi, br);
acc_hi = _mm256_add_epi64(acc_hi, prod_hi);
a_ptr = a_ptr.sub(8);
b_ptr = b_ptr.add(1);
}
// Store final result
_mm256_storeu_si256(dst.as_mut_ptr() as *mut __m256i, acc_lo);
_mm256_storeu_si256(dst.as_mut_ptr().add(4) as *mut __m256i, acc_hi);
}
}
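// Illustrative scalar model (a sketch, not part of the committed file): what
// the AVX kernel above computes for a single output coefficient k, assuming
// the 8-lane block layout (block t of `a` occupies a[8*t..8*t + 8]) and that
// every value fits in i32 so the low-32-bit multiply is exact.
fn i64_convolution_by_const_1coeff_scalar(k: usize, dst: &mut [i64; 8], a: &[i64], a_size: usize, b: &[i64]) {
    dst.fill(0);
    if k >= a_size + b.len() {
        return;
    }
    // j contributes to coefficient k iff 0 <= k - j < a_size and j < b_size.
    let j_min = k.saturating_sub(a_size - 1);
    let j_max = (k + 1).min(b.len());
    for j in j_min..j_max {
        for l in 0..8 {
            dst[l] += a[8 * (k - j) + l] * b[j];
        }
    }
}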
/// # Safety
/// Caller must ensure the CPU supports AVX2.
/// Assumes all values in `a` and `b` fit in i32 (so i32×i32→i64 is exact).
#[target_feature(enable = "avx2")]
pub unsafe fn i64_convolution_by_real_const_2coeffs_avx(
k: usize,
dst: &mut [i64; 16],
a: &[i64],
a_size: usize,
b: &[i64], // real scalars, stride-1
) {
use core::arch::x86_64::{
__m256i, _mm256_add_epi64, _mm256_loadu_si256, _mm256_mul_epi32, _mm256_set1_epi32, _mm256_setzero_si256,
_mm256_storeu_si256,
};
let b_size: usize = b.len();
debug_assert!(a.len() >= 8 * a_size);
let k0: usize = k;
let k1: usize = k + 1;
let bound: usize = a_size + b_size;
if k0 >= bound {
unsafe {
let zero: __m256i = _mm256_setzero_si256();
let dst_ptr: *mut i64 = dst.as_mut_ptr();
_mm256_storeu_si256(dst_ptr as *mut __m256i, zero);
_mm256_storeu_si256(dst_ptr.add(4) as *mut __m256i, zero);
_mm256_storeu_si256(dst_ptr.add(8) as *mut __m256i, zero);
_mm256_storeu_si256(dst_ptr.add(12) as *mut __m256i, zero);
}
return;
}
unsafe {
let mut acc_lo_k0: __m256i = _mm256_setzero_si256();
let mut acc_hi_k0: __m256i = _mm256_setzero_si256();
let mut acc_lo_k1: __m256i = _mm256_setzero_si256();
let mut acc_hi_k1: __m256i = _mm256_setzero_si256();
let j0_min: usize = (k0 + 1).saturating_sub(a_size);
let j0_max: usize = (k0 + 1).min(b_size);
if k1 >= bound {
let mut a_k0_ptr: *const i64 = a.as_ptr().add(8 * (k0 - j0_min));
let mut b_ptr: *const i64 = b.as_ptr().add(j0_min);
// Contributions to k0 only
for _ in 0..j0_max - j0_min {
// Broadcast b[j] as i32
let br: __m256i = _mm256_set1_epi32(*b_ptr as i32);
// Load 4×i64 (low half) and 4×i64 (high half)
let a_lo_k0: __m256i = _mm256_loadu_si256(a_k0_ptr as *const __m256i);
let a_hi_k0: __m256i = _mm256_loadu_si256(a_k0_ptr.add(4) as *const __m256i);
acc_lo_k0 = _mm256_add_epi64(acc_lo_k0, _mm256_mul_epi32(a_lo_k0, br));
acc_hi_k0 = _mm256_add_epi64(acc_hi_k0, _mm256_mul_epi32(a_hi_k0, br));
a_k0_ptr = a_k0_ptr.sub(8);
b_ptr = b_ptr.add(1);
}
} else {
let j1_min: usize = (k1 + 1).saturating_sub(a_size);
let j1_max: usize = (k1 + 1).min(b_size);
let mut a_k0_ptr: *const i64 = a.as_ptr().add(8 * (k0 - j0_min));
let mut a_k1_ptr: *const i64 = a.as_ptr().add(8 * (k1 - j1_min));
let mut b_ptr: *const i64 = b.as_ptr().add(j0_min);
// Region 1: k0 only, j ∈ [j0_min, j1_min)
for _ in 0..j1_min - j0_min {
let br: __m256i = _mm256_set1_epi32(*b_ptr as i32);
let a_k0_lo: __m256i = _mm256_loadu_si256(a_k0_ptr as *const __m256i);
let a_k0_hi: __m256i = _mm256_loadu_si256(a_k0_ptr.add(4) as *const __m256i);
acc_lo_k0 = _mm256_add_epi64(acc_lo_k0, _mm256_mul_epi32(a_k0_lo, br));
acc_hi_k0 = _mm256_add_epi64(acc_hi_k0, _mm256_mul_epi32(a_k0_hi, br));
a_k0_ptr = a_k0_ptr.sub(8);
b_ptr = b_ptr.add(1);
}
// Region 2: overlap, contributions to both k0 and k1, j ∈ [j1_min, j0_max)
// Save one load on b: broadcast once and reuse.
for _ in 0..j0_max - j1_min {
let br: __m256i = _mm256_set1_epi32(*b_ptr as i32);
let a_lo_k0: __m256i = _mm256_loadu_si256(a_k0_ptr as *const __m256i);
let a_hi_k0: __m256i = _mm256_loadu_si256(a_k0_ptr.add(4) as *const __m256i);
let a_lo_k1: __m256i = _mm256_loadu_si256(a_k1_ptr as *const __m256i);
let a_hi_k1: __m256i = _mm256_loadu_si256(a_k1_ptr.add(4) as *const __m256i);
// k0
acc_lo_k0 = _mm256_add_epi64(acc_lo_k0, _mm256_mul_epi32(a_lo_k0, br));
acc_hi_k0 = _mm256_add_epi64(acc_hi_k0, _mm256_mul_epi32(a_hi_k0, br));
// k1
acc_lo_k1 = _mm256_add_epi64(acc_lo_k1, _mm256_mul_epi32(a_lo_k1, br));
acc_hi_k1 = _mm256_add_epi64(acc_hi_k1, _mm256_mul_epi32(a_hi_k1, br));
a_k0_ptr = a_k0_ptr.sub(8);
a_k1_ptr = a_k1_ptr.sub(8);
b_ptr = b_ptr.add(1);
}
// Region 3: k1 only, j ∈ [j0_max, j1_max)
for _ in 0..j1_max - j0_max {
let br: __m256i = _mm256_set1_epi32(*b_ptr as i32);
let a_lo_k1: __m256i = _mm256_loadu_si256(a_k1_ptr as *const __m256i);
let a_hi_k1: __m256i = _mm256_loadu_si256(a_k1_ptr.add(4) as *const __m256i);
acc_lo_k1 = _mm256_add_epi64(acc_lo_k1, _mm256_mul_epi32(a_lo_k1, br));
acc_hi_k1 = _mm256_add_epi64(acc_hi_k1, _mm256_mul_epi32(a_hi_k1, br));
a_k1_ptr = a_k1_ptr.sub(8);
b_ptr = b_ptr.add(1);
}
}
let dst_ptr: *mut i64 = dst.as_mut_ptr();
_mm256_storeu_si256(dst_ptr as *mut __m256i, acc_lo_k0);
_mm256_storeu_si256(dst_ptr.add(4) as *mut __m256i, acc_hi_k0);
_mm256_storeu_si256(dst_ptr.add(8) as *mut __m256i, acc_lo_k1);
_mm256_storeu_si256(dst_ptr.add(12) as *mut __m256i, acc_hi_k1);
}
}
/// # Safety
/// Caller must ensure the CPU supports AVX2 (e.g., via `is_x86_feature_detected!("avx2")`);
#[target_feature(enable = "avx")]
pub fn i64_extract_1blk_contiguous_avx(n: usize, offset: usize, rows: usize, blk: usize, dst: &mut [i64], src: &[i64]) {
use core::arch::x86_64::{__m256i, _mm256_loadu_si256, _mm256_storeu_si256};
unsafe {
let mut src_ptr: *const __m256i = src.as_ptr().add(offset + (blk << 3)) as *const __m256i; // src + 8*blk
let mut dst_ptr: *mut __m256i = dst.as_mut_ptr() as *mut __m256i;
let step: usize = n >> 2;
// Each iteration copies 8 i64; advance src by n i64 each row
for _ in 0..rows {
let v: __m256i = _mm256_loadu_si256(src_ptr);
_mm256_storeu_si256(dst_ptr, v);
let v: __m256i = _mm256_loadu_si256(src_ptr.add(1));
_mm256_storeu_si256(dst_ptr.add(1), v);
dst_ptr = dst_ptr.add(2);
src_ptr = src_ptr.add(step);
}
}
}
/// # Safety
/// Caller must ensure the CPU supports AVX2 (e.g., via `is_x86_feature_detected!("avx2")`);
#[target_feature(enable = "avx")]
pub fn i64_save_1blk_contiguous_avx(n: usize, offset: usize, rows: usize, blk: usize, dst: &mut [i64], src: &[i64]) {
use core::arch::x86_64::{__m256i, _mm256_loadu_si256, _mm256_storeu_si256};
unsafe {
let mut src_ptr: *const __m256i = src.as_ptr() as *const __m256i;
let mut dst_ptr: *mut __m256i = dst.as_mut_ptr().add(offset + (blk << 3)) as *mut __m256i; // dst + 8*blk
let step: usize = n >> 2;
// Each iteration copies 8 i64; advance dst by n i64 each row
for _ in 0..rows {
let v: __m256i = _mm256_loadu_si256(src_ptr);
_mm256_storeu_si256(dst_ptr, v);
let v: __m256i = _mm256_loadu_si256(src_ptr.add(1));
_mm256_storeu_si256(dst_ptr.add(1), v);
dst_ptr = dst_ptr.add(step);
src_ptr = src_ptr.add(2);
}
}
}
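// Illustrative scalar models (sketches, not part of the committed file) of the
// two block-copy kernels above: `extract` gathers row-strided 8-lane blocks into
// a contiguous buffer, `save` scatters a contiguous buffer back into the rows.
fn i64_extract_1blk_contiguous_scalar(n: usize, offset: usize, rows: usize, blk: usize, dst: &mut [i64], src: &[i64]) {
    for r in 0..rows {
        for l in 0..8 {
            dst[8 * r + l] = src[offset + 8 * blk + n * r + l];
        }
    }
}
fn i64_save_1blk_contiguous_scalar(n: usize, offset: usize, rows: usize, blk: usize, dst: &mut [i64], src: &[i64]) {
    for r in 0..rows {
        for l in 0..8 {
            dst[offset + 8 * blk + n * r + l] = src[8 * r + l];
        }
    }
}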

View File

@@ -1,7 +1,7 @@
// ─────────────────────────────────────────────────────────────
// Build the backend **only when ALL conditions are satisfied**
// ─────────────────────────────────────────────────────────────
#![cfg(all(feature = "enable-avx", target_arch = "x86_64", target_feature = "avx2", target_feature = "fma"))]
//#![cfg(all(feature = "enable-avx", target_arch = "x86_64", target_feature = "avx2", target_feature = "fma"))]
// If the user enables this backend but targets a non-x86_64 CPU → abort
#[cfg(all(feature = "enable-avx", not(target_arch = "x86_64")))]
@@ -15,6 +15,7 @@ compile_error!("feature `enable-avx` requires AVX2. Build with RUSTFLAGS=\"-C ta
#[cfg(all(feature = "enable-avx", target_arch = "x86_64", not(target_feature = "fma")))]
compile_error!("feature `enable-avx` requires FMA. Build with RUSTFLAGS=\"-C target-feature=+fma\".");
mod convolution;
mod module;
mod reim;
mod reim4;

View File

@@ -5,13 +5,18 @@ use poulpy_hal::{
oep::ModuleNewImpl,
reference::{
fft64::{
convolution::{
I64ConvolutionByConst1Coeff, I64ConvolutionByConst2Coeffs, I64Extract1BlkContiguous, I64Save1BlkContiguous,
},
reim::{
ReimAdd, ReimAddInplace, ReimAddMul, ReimCopy, ReimDFTExecute, ReimFFTTable, ReimFromZnx, ReimIFFTTable, ReimMul,
ReimMulInplace, ReimNegate, ReimNegateInplace, ReimSub, ReimSubInplace, ReimSubNegateInplace, ReimToZnx,
ReimToZnxInplace, ReimZero, reim_copy_ref, reim_zero_ref,
},
reim4::{
Reim4Extract1Blk, Reim4Mat1ColProd, Reim4Mat2Cols2ndColProd, Reim4Mat2ColsProd, Reim4Save1Blk, Reim4Save2Blks,
Reim4Convolution1Coeff, Reim4Convolution2Coeffs, Reim4ConvolutionByRealConst1Coeff,
Reim4ConvolutionByRealConst2Coeffs, Reim4Extract1BlkContiguous, Reim4Mat1ColProd, Reim4Mat2Cols2ndColProd,
Reim4Mat2ColsProd, Reim4Save1Blk, Reim4Save1BlkContiguous, Reim4Save2Blks,
},
},
znx::{
@@ -26,6 +31,10 @@ use poulpy_hal::{
use crate::{
FFT64Avx,
convolution::{
i64_convolution_by_const_1coeff_avx, i64_convolution_by_real_const_2coeffs_avx, i64_extract_1blk_contiguous_avx,
i64_save_1blk_contiguous_avx,
},
reim::{
ReimFFTAvx, ReimIFFTAvx, reim_add_avx2_fma, reim_add_inplace_avx2_fma, reim_addmul_avx2_fma, reim_from_znx_i64_bnd50_fma,
reim_mul_avx2_fma, reim_mul_inplace_avx2_fma, reim_negate_avx2_fma, reim_negate_inplace_avx2_fma, reim_sub_avx2_fma,
@@ -33,8 +42,10 @@ use crate::{
},
reim_to_znx_i64_bnd63_avx2_fma,
reim4::{
reim4_extract_1blk_from_reim_avx, reim4_save_1blk_to_reim_avx, reim4_save_2blk_to_reim_avx,
reim4_vec_mat1col_product_avx, reim4_vec_mat2cols_2ndcol_product_avx, reim4_vec_mat2cols_product_avx,
reim4_convolution_1coeff_avx, reim4_convolution_2coeffs_avx, reim4_convolution_by_real_const_1coeff_avx,
reim4_convolution_by_real_const_2coeffs_avx, reim4_extract_1blk_from_reim_contiguous_avx, reim4_save_1blk_to_reim_avx,
reim4_save_1blk_to_reim_contiguous_avx, reim4_save_2blk_to_reim_avx, reim4_vec_mat1col_product_avx,
reim4_vec_mat2cols_2ndcol_product_avx, reim4_vec_mat2cols_product_avx,
},
znx_avx::{
znx_add_avx, znx_add_inplace_avx, znx_automorphism_avx, znx_extract_digit_addmul_avx, znx_mul_add_power_of_two_avx,
@@ -470,11 +481,55 @@ impl ReimZero for FFT64Avx {
}
}
impl Reim4Extract1Blk for FFT64Avx {
impl Reim4Convolution1Coeff for FFT64Avx {
#[inline(always)]
fn reim4_extract_1blk(m: usize, rows: usize, blk: usize, dst: &mut [f64], src: &[f64]) {
fn reim4_convolution_1coeff(k: usize, dst: &mut [f64; 8], a: &[f64], a_size: usize, b: &[f64], b_size: usize) {
unsafe {
reim4_extract_1blk_from_reim_avx(m, rows, blk, dst, src);
reim4_convolution_1coeff_avx(k, dst, a, a_size, b, b_size);
}
}
}
impl Reim4Convolution2Coeffs for FFT64Avx {
#[inline(always)]
fn reim4_convolution_2coeffs(k: usize, dst: &mut [f64; 16], a: &[f64], a_size: usize, b: &[f64], b_size: usize) {
unsafe {
reim4_convolution_2coeffs_avx(k, dst, a, a_size, b, b_size);
}
}
}
impl Reim4ConvolutionByRealConst1Coeff for FFT64Avx {
#[inline(always)]
fn reim4_convolution_by_real_const_1coeff(k: usize, dst: &mut [f64; 8], a: &[f64], a_size: usize, b: &[f64]) {
unsafe {
reim4_convolution_by_real_const_1coeff_avx(k, dst, a, a_size, b);
}
}
}
impl Reim4ConvolutionByRealConst2Coeffs for FFT64Avx {
#[inline(always)]
fn reim4_convolution_by_real_const_2coeffs(k: usize, dst: &mut [f64; 16], a: &[f64], a_size: usize, b: &[f64]) {
unsafe {
reim4_convolution_by_real_const_2coeffs_avx(k, dst, a, a_size, b);
}
}
}
impl Reim4Extract1BlkContiguous for FFT64Avx {
#[inline(always)]
fn reim4_extract_1blk_contiguous(m: usize, rows: usize, blk: usize, dst: &mut [f64], src: &[f64]) {
unsafe {
reim4_extract_1blk_from_reim_contiguous_avx(m, rows, blk, dst, src);
}
}
}
impl Reim4Save1BlkContiguous for FFT64Avx {
fn reim4_save_1blk_contiguous(m: usize, rows: usize, blk: usize, dst: &mut [f64], src: &[f64]) {
unsafe {
reim4_save_1blk_to_reim_contiguous_avx(m, rows, blk, dst, src);
}
}
}
@@ -523,3 +578,39 @@ impl Reim4Mat2Cols2ndColProd for FFT64Avx {
}
}
}
impl I64ConvolutionByConst1Coeff for FFT64Avx {
#[inline(always)]
fn i64_convolution_by_const_1coeff(k: usize, dst: &mut [i64; 8], a: &[i64], a_size: usize, b: &[i64]) {
unsafe {
i64_convolution_by_const_1coeff_avx(k, dst, a, a_size, b);
}
}
}
impl I64ConvolutionByConst2Coeffs for FFT64Avx {
#[inline(always)]
fn i64_convolution_by_const_2coeffs(k: usize, dst: &mut [i64; 16], a: &[i64], a_size: usize, b: &[i64]) {
unsafe {
i64_convolution_by_real_const_2coeffs_avx(k, dst, a, a_size, b);
}
}
}
impl I64Save1BlkContiguous for FFT64Avx {
#[inline(always)]
fn i64_save_1blk_contiguous(n: usize, offset: usize, rows: usize, blk: usize, dst: &mut [i64], src: &[i64]) {
unsafe {
i64_save_1blk_contiguous_avx(n, offset, rows, blk, dst, src);
}
}
}
impl I64Extract1BlkContiguous for FFT64Avx {
#[inline(always)]
fn i64_extract_1blk_contiguous(n: usize, offset: usize, rows: usize, blk: usize, dst: &mut [i64], src: &[i64]) {
unsafe {
i64_extract_1blk_contiguous_avx(n, offset, rows, blk, dst, src);
}
}
}

View File

@@ -18,11 +18,7 @@ pub(crate) fn fft_avx2_fma(m: usize, omg: &[f64], data: &mut [f64]) {
let (re, im) = data.split_at_mut(m);
if m == 16 {
fft16_avx2_fma(
as_arr_mut::<16, f64>(re),
as_arr_mut::<16, f64>(im),
as_arr::<16, f64>(omg),
)
fft16_avx2_fma(as_arr_mut::<16, f64>(re), as_arr_mut::<16, f64>(im), as_arr::<16, f64>(omg))
} else if m <= 2048 {
fft_bfs_16_avx2_fma(m, re, im, omg, 0);
} else {
@@ -70,12 +66,7 @@ fn fft_bfs_16_avx2_fma(m: usize, re: &mut [f64], im: &mut [f64], omg: &[f64], mu
while mm > 16 {
let h: usize = mm >> 2;
for off in (0..m).step_by(mm) {
bitwiddle_fft_avx2_fma(
h,
&mut re[off..],
&mut im[off..],
as_arr::<4, f64>(&omg[pos..]),
);
bitwiddle_fft_avx2_fma(h, &mut re[off..], &mut im[off..], as_arr::<4, f64>(&omg[pos..]));
pos += 4;
}
@@ -232,16 +223,10 @@ fn test_fft_avx2_fma() {
let mut values_0: Vec<f64> = vec![0f64; m << 1];
let scale: f64 = 1.0f64 / m as f64;
values_0
.iter_mut()
.enumerate()
.for_each(|(i, x)| *x = (i + 1) as f64 * scale);
values_0.iter_mut().enumerate().for_each(|(i, x)| *x = (i + 1) as f64 * scale);
let mut values_1: Vec<f64> = vec![0f64; m << 1];
values_1
.iter_mut()
.zip(values_0.iter())
.for_each(|(y, x)| *y = *x);
values_1.iter_mut().zip(values_0.iter()).for_each(|(y, x)| *y = *x);
ReimFFTAvx::reim_dft_execute(&table, &mut values_0);
ReimFFTRef::reim_dft_execute(&table, &mut values_1);
@@ -250,14 +235,7 @@ fn test_fft_avx2_fma() {
for i in 0..m * 2 {
let diff: f64 = (values_0[i] - values_1[i]).abs();
assert!(
diff <= max_diff,
"{} -> {}-{} = {}",
i,
values_0[i],
values_1[i],
diff
)
assert!(diff <= max_diff, "{} -> {}-{} = {}", i, values_0[i], values_1[i], diff)
}
}

View File

@@ -17,11 +17,7 @@ pub(crate) fn ifft_avx2_fma(m: usize, omg: &[f64], data: &mut [f64]) {
let (re, im) = data.split_at_mut(m);
if m == 16 {
ifft16_avx2_fma(
as_arr_mut::<16, f64>(re),
as_arr_mut::<16, f64>(im),
as_arr::<16, f64>(omg),
)
ifft16_avx2_fma(as_arr_mut::<16, f64>(re), as_arr_mut::<16, f64>(im), as_arr::<16, f64>(omg))
} else if m <= 2048 {
ifft_bfs_16_avx2_fma(m, re, im, omg, 0);
} else {
@@ -72,12 +68,7 @@ fn ifft_bfs_16_avx2_fma(m: usize, re: &mut [f64], im: &mut [f64], omg: &[f64], m
while h < m_half {
let mm: usize = h << 2;
for off in (0..m).step_by(mm) {
inv_bitwiddle_ifft_avx2_fma(
h,
&mut re[off..],
&mut im[off..],
as_arr::<4, f64>(&omg[pos..]),
);
inv_bitwiddle_ifft_avx2_fma(h, &mut re[off..], &mut im[off..], as_arr::<4, f64>(&omg[pos..]));
pos += 4;
}
h = mm;
@@ -225,16 +216,10 @@ fn test_ifft_avx2_fma() {
let mut values_0: Vec<f64> = vec![0f64; m << 1];
let scale: f64 = 1.0f64 / m as f64;
values_0
.iter_mut()
.enumerate()
.for_each(|(i, x)| *x = (i + 1) as f64 * scale);
values_0.iter_mut().enumerate().for_each(|(i, x)| *x = (i + 1) as f64 * scale);
let mut values_1: Vec<f64> = vec![0f64; m << 1];
values_1
.iter_mut()
.zip(values_0.iter())
.for_each(|(y, x)| *y = *x);
values_1.iter_mut().zip(values_0.iter()).for_each(|(y, x)| *y = *x);
ReimIFFTAvx::reim_dft_execute(&table, &mut values_0);
ReimIFFTRef::reim_dft_execute(&table, &mut values_1);
@@ -243,14 +228,7 @@ fn test_ifft_avx2_fma() {
for i in 0..m * 2 {
let diff: f64 = (values_0[i] - values_1[i]).abs();
assert!(
diff <= max_diff,
"{} -> {}-{} = {}",
i,
values_0[i],
values_1[i],
diff
)
assert!(diff <= max_diff, "{} -> {}-{} = {}", i, values_0[i], values_1[i], diff)
}
}

View File

@@ -32,10 +32,7 @@ use rand_distr::num_traits::{Float, FloatConst};
use crate::reim::{fft_avx2_fma::fft_avx2_fma, ifft_avx2_fma::ifft_avx2_fma};
global_asm!(
include_str!("fft16_avx2_fma.s"),
include_str!("ifft16_avx2_fma.s")
);
global_asm!(include_str!("fft16_avx2_fma.s"), include_str!("ifft16_avx2_fma.s"));
#[inline(always)]
pub(crate) fn as_arr<const SIZE: usize, R: Float + FloatConst>(x: &[R]) -> &[R; SIZE] {

View File

@@ -1,7 +1,7 @@
/// # Safety
/// Caller must ensure the CPU supports AVX2 (e.g., via `is_x86_feature_detected!("avx2")`);
#[target_feature(enable = "avx")]
pub fn reim4_extract_1blk_from_reim_avx(m: usize, rows: usize, blk: usize, dst: &mut [f64], src: &[f64]) {
pub fn reim4_extract_1blk_from_reim_contiguous_avx(m: usize, rows: usize, blk: usize, dst: &mut [f64], src: &[f64]) {
use core::arch::x86_64::{__m256d, _mm256_loadu_pd, _mm256_storeu_pd};
unsafe {
@@ -20,6 +20,28 @@ pub fn reim4_extract_1blk_from_reim_avx(m: usize, rows: usize, blk: usize, dst:
}
}
/// # Safety
/// Caller must ensure the CPU supports AVX2 (e.g., via `is_x86_feature_detected!("avx2")`);
#[target_feature(enable = "avx")]
pub fn reim4_save_1blk_to_reim_contiguous_avx(m: usize, rows: usize, blk: usize, dst: &mut [f64], src: &[f64]) {
use core::arch::x86_64::{__m256d, _mm256_loadu_pd, _mm256_storeu_pd};
unsafe {
let mut src_ptr: *const __m256d = src.as_ptr() as *const __m256d;
let mut dst_ptr: *mut __m256d = dst.as_mut_ptr().add(blk << 2) as *mut __m256d; // dst + 4*blk
let step: usize = m >> 2;
// Each iteration copies 4 doubles; advance dst by m doubles each row
for _ in 0..2 * rows {
let v: __m256d = _mm256_loadu_pd(src_ptr as *const f64);
_mm256_storeu_pd(dst_ptr as *mut f64, v);
dst_ptr = dst_ptr.add(step);
src_ptr = src_ptr.add(1);
}
}
}
/// # Safety
/// Caller must ensure the CPU supports AVX2 (e.g., via `is_x86_feature_detected!("avx2")`);
#[target_feature(enable = "avx2,fma")]
@@ -148,11 +170,7 @@ pub fn reim4_vec_mat2cols_product_avx(nrows: usize, dst: &mut [f64], u: &[f64],
#[cfg(debug_assertions)]
{
assert!(
dst.len() >= 8,
"dst must be at least 8 doubles but is {}",
dst.len()
);
assert!(dst.len() >= 8, "dst must be at least 8 doubles but is {}", dst.len());
assert!(
u.len() >= nrows * 8,
"u must be at least nrows={} * 8 doubles but is {}",
@@ -185,16 +203,16 @@ pub fn reim4_vec_mat2cols_product_avx(nrows: usize, dst: &mut [f64], u: &[f64],
let br: __m256d = _mm256_loadu_pd(v_ptr.add(8));
let bi: __m256d = _mm256_loadu_pd(v_ptr.add(12));
// re1 = re1 - ui*ai; re2 = re2 - ui*bi;
// re1 = ui*ai - re1; re2 = ui*bi - re2;
re1 = _mm256_fmsub_pd(ui, ai, re1);
re2 = _mm256_fmsub_pd(ui, bi, re2);
// im1 = im1 + ur*ai; im2 = im2 + ur*bi;
// im1 = ur*ai + im1; im2 = ur*bi + im2;
im1 = _mm256_fmadd_pd(ur, ai, im1);
im2 = _mm256_fmadd_pd(ur, bi, im2);
// re1 = re1 - ur*ar; re2 = re2 - ur*br;
// re1 = ur*ar - re1; re2 = ur*br - re2;
re1 = _mm256_fmsub_pd(ur, ar, re1);
re2 = _mm256_fmsub_pd(ur, br, re2);
// im1 = im1 + ui*ar; im2 = im2 + ui*br;
// im1 = ui*ar + im1; im2 = ui*br + im2;
im1 = _mm256_fmadd_pd(ui, ar, im1);
im2 = _mm256_fmadd_pd(ui, br, im2);
@@ -219,10 +237,7 @@ pub fn reim4_vec_mat2cols_2ndcol_product_avx(nrows: usize, dst: &mut [f64], u: &
{
assert_eq!(dst.len(), 16, "dst must have 16 doubles");
assert!(u.len() >= nrows * 8, "u must be at least nrows * 8 doubles");
assert!(
v.len() >= nrows * 16,
"v must be at least nrows * 16 doubles"
);
assert!(v.len() >= nrows * 16, "v must be at least nrows * 16 doubles");
}
unsafe {
@@ -239,13 +254,13 @@ pub fn reim4_vec_mat2cols_2ndcol_product_avx(nrows: usize, dst: &mut [f64], u: &
let ar: __m256d = _mm256_loadu_pd(v_ptr);
let ai: __m256d = _mm256_loadu_pd(v_ptr.add(4));
// re1 = re1 - ui*ai; re2 = re2 - ui*bi;
// re1 = ui*ai - re1;
re1 = _mm256_fmsub_pd(ui, ai, re1);
// im1 = im1 + ur*ai; im2 = im2 + ur*bi;
// im1 = im1 + ur*ai;
im1 = _mm256_fmadd_pd(ur, ai, im1);
// re1 = re1 - ur*ar; re2 = re2 - ur*br;
// re1 = ur*ar - re1;
re1 = _mm256_fmsub_pd(ur, ar, re1);
// im1 = im1 + ui*ar; im2 = im2 + ui*br;
// im1 = im1 + ui*ar;
im1 = _mm256_fmadd_pd(ui, ar, im1);
u_ptr = u_ptr.add(8);
@@ -256,3 +271,360 @@ pub fn reim4_vec_mat2cols_2ndcol_product_avx(nrows: usize, dst: &mut [f64], u: &
_mm256_storeu_pd(dst.as_mut_ptr().add(4), im1);
}
}
/// # Safety
/// Caller must ensure the CPU supports AVX2 and FMA (e.g. `is_x86_feature_detected!("avx2")`).
#[target_feature(enable = "avx2", enable = "fma")]
pub unsafe fn reim4_convolution_1coeff_avx(k: usize, dst: &mut [f64; 8], a: &[f64], a_size: usize, b: &[f64], b_size: usize) {
use core::arch::x86_64::{__m256d, _mm256_fmadd_pd, _mm256_fmsub_pd, _mm256_loadu_pd, _mm256_setzero_pd, _mm256_storeu_pd};
unsafe {
// Scalar guard — same semantics as reference implementation
if k >= a_size + b_size {
let zero: __m256d = _mm256_setzero_pd();
let dst_ptr: *mut f64 = dst.as_mut_ptr();
_mm256_storeu_pd(dst_ptr, zero);
_mm256_storeu_pd(dst_ptr.add(4), zero);
return;
}
let j_min: usize = k.saturating_sub(a_size - 1);
let j_max: usize = (k + 1).min(b_size);
// acc_re = dst[0..4], acc_im = dst[4..8]
let mut acc_re: __m256d = _mm256_setzero_pd();
let mut acc_im: __m256d = _mm256_setzero_pd();
let mut a_ptr: *const f64 = a.as_ptr().add(8 * (k - j_min));
let mut b_ptr: *const f64 = b.as_ptr().add(8 * j_min);
for _ in 0..j_max - j_min {
// Load a[(k - j)]
let ar: __m256d = _mm256_loadu_pd(a_ptr);
let ai: __m256d = _mm256_loadu_pd(a_ptr.add(4));
// Load b[j]
let br: __m256d = _mm256_loadu_pd(b_ptr);
let bi: __m256d = _mm256_loadu_pd(b_ptr.add(4));
// acc_re = ai*bi - acc_re
acc_re = _mm256_fmsub_pd(ai, bi, acc_re);
// acc_im = ar*bi + acc_im
acc_im = _mm256_fmadd_pd(ar, bi, acc_im);
// acc_re = ar*br - acc_re
acc_re = _mm256_fmsub_pd(ar, br, acc_re);
// acc_im = acc_im + ai*br
acc_im = _mm256_fmadd_pd(ai, br, acc_im);
a_ptr = a_ptr.sub(8);
b_ptr = b_ptr.add(8);
}
// Store accumulators into dst
_mm256_storeu_pd(dst.as_mut_ptr(), acc_re);
_mm256_storeu_pd(dst.as_mut_ptr().add(4), acc_im);
}
}
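// Illustrative scalar model (a sketch, not part of the committed file): the
// complex multiply-accumulate performed above, on 4-lane re/im blocks
// (block t of `a` is a[8*t..8*t + 4] = re, a[8*t + 4..8*t + 8] = im; same for `b`).
// Up to FMA rounding, the AVX kernel produces the same values.
fn reim4_convolution_1coeff_scalar(k: usize, dst: &mut [f64; 8], a: &[f64], a_size: usize, b: &[f64], b_size: usize) {
    dst.fill(0.0);
    if k >= a_size + b_size {
        return;
    }
    let j_min = k.saturating_sub(a_size - 1);
    let j_max = (k + 1).min(b_size);
    for j in j_min..j_max {
        for l in 0..4 {
            let (ar, ai) = (a[8 * (k - j) + l], a[8 * (k - j) + 4 + l]);
            let (br, bi) = (b[8 * j + l], b[8 * j + 4 + l]);
            dst[l] += ar * br - ai * bi; // real part
            dst[4 + l] += ar * bi + ai * br; // imaginary part
        }
    }
}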
/// # Safety
/// Caller must ensure the CPU supports AVX2 and FMA (e.g. `is_x86_feature_detected!("avx2")`).
#[target_feature(enable = "avx2", enable = "fma")]
pub unsafe fn reim4_convolution_2coeffs_avx(k: usize, dst: &mut [f64; 16], a: &[f64], a_size: usize, b: &[f64], b_size: usize) {
use core::arch::x86_64::{__m256d, _mm256_fmadd_pd, _mm256_fnmadd_pd, _mm256_loadu_pd, _mm256_setzero_pd, _mm256_storeu_pd};
debug_assert!(a.len() >= 8 * a_size);
debug_assert!(b.len() >= 8 * b_size);
let k0: usize = k;
let k1: usize = k + 1;
let bound: usize = a_size + b_size;
// Since k1 = k0 + 1, k0 being out of range implies k1 is as well, so both
// output coefficients are zeroed.
if k0 >= bound {
unsafe {
let zero: __m256d = _mm256_setzero_pd();
let dst_ptr: *mut f64 = dst.as_mut_ptr();
_mm256_storeu_pd(dst_ptr, zero);
_mm256_storeu_pd(dst_ptr.add(4), zero);
_mm256_storeu_pd(dst_ptr.add(8), zero);
_mm256_storeu_pd(dst_ptr.add(12), zero);
}
return;
}
unsafe {
let mut acc_re_k0: __m256d = _mm256_setzero_pd();
let mut acc_im_k0: __m256d = _mm256_setzero_pd();
let mut acc_re_k1: __m256d = _mm256_setzero_pd();
let mut acc_im_k1: __m256d = _mm256_setzero_pd();
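// For output coefficient k, the b-index j contributes iff 0 <= k - j < a_size
// and j < b_size, i.e. j ∈ [max(0, k + 1 - a_size), min(k + 1, b_size)).
// Splitting the union of the k0 and k1 ranges yields the regions handled below
// (a single region when k1 is out of range, three otherwise).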
let j0_min: usize = (k0 + 1).saturating_sub(a_size);
let j0_max: usize = (k0 + 1).min(b_size);
if k1 >= bound {
let mut a_k0_ptr: *const f64 = a.as_ptr().add(8 * (k0 - j0_min));
let mut b_ptr: *const f64 = b.as_ptr().add(8 * j0_min);
// Contributions to k0 only, j ∈ [j0_min, j0_max)
for _ in 0..j0_max - j0_min {
let ar: __m256d = _mm256_loadu_pd(a_k0_ptr);
let ai: __m256d = _mm256_loadu_pd(a_k0_ptr.add(4));
let br: __m256d = _mm256_loadu_pd(b_ptr);
let bi: __m256d = _mm256_loadu_pd(b_ptr.add(4));
acc_re_k0 = _mm256_fmadd_pd(ar, br, acc_re_k0);
acc_re_k0 = _mm256_fnmadd_pd(ai, bi, acc_re_k0);
acc_im_k0 = _mm256_fmadd_pd(ar, bi, acc_im_k0);
acc_im_k0 = _mm256_fmadd_pd(ai, br, acc_im_k0);
a_k0_ptr = a_k0_ptr.sub(8);
b_ptr = b_ptr.add(8);
}
} else {
let j1_min: usize = (k1 + 1).saturating_sub(a_size);
let j1_max: usize = (k1 + 1).min(b_size);
let mut a_k0_ptr: *const f64 = a.as_ptr().add(8 * (k0 - j0_min));
let mut a_k1_ptr: *const f64 = a.as_ptr().add(8 * (k1 - j1_min));
let mut b_ptr: *const f64 = b.as_ptr().add(8 * j0_min);
// Region 1: contributions to k0 only, j ∈ [j0_min, j1_min)
for _ in 0..j1_min - j0_min {
let ar: __m256d = _mm256_loadu_pd(a_k0_ptr);
let ai: __m256d = _mm256_loadu_pd(a_k0_ptr.add(4));
let br: __m256d = _mm256_loadu_pd(b_ptr);
let bi: __m256d = _mm256_loadu_pd(b_ptr.add(4));
acc_re_k0 = _mm256_fmadd_pd(ar, br, acc_re_k0);
acc_re_k0 = _mm256_fnmadd_pd(ai, bi, acc_re_k0);
acc_im_k0 = _mm256_fmadd_pd(ar, bi, acc_im_k0);
acc_im_k0 = _mm256_fmadd_pd(ai, br, acc_im_k0);
a_k0_ptr = a_k0_ptr.sub(8);
b_ptr = b_ptr.add(8);
}
// Region 2: overlap, contributions to both k0 and k1, j ∈ [j1_min, j0_max)
// We can save one load on b.
for _ in 0..j0_max - j1_min {
let ar0: __m256d = _mm256_loadu_pd(a_k0_ptr);
let ai0: __m256d = _mm256_loadu_pd(a_k0_ptr.add(4));
let ar1: __m256d = _mm256_loadu_pd(a_k1_ptr);
let ai1: __m256d = _mm256_loadu_pd(a_k1_ptr.add(4));
let br: __m256d = _mm256_loadu_pd(b_ptr);
let bi: __m256d = _mm256_loadu_pd(b_ptr.add(4));
// k0
acc_re_k0 = _mm256_fmadd_pd(ar0, br, acc_re_k0);
acc_re_k0 = _mm256_fnmadd_pd(ai0, bi, acc_re_k0);
acc_im_k0 = _mm256_fmadd_pd(ar0, bi, acc_im_k0);
acc_im_k0 = _mm256_fmadd_pd(ai0, br, acc_im_k0);
// k1
acc_re_k1 = _mm256_fmadd_pd(ar1, br, acc_re_k1);
acc_re_k1 = _mm256_fnmadd_pd(ai1, bi, acc_re_k1);
acc_im_k1 = _mm256_fmadd_pd(ar1, bi, acc_im_k1);
acc_im_k1 = _mm256_fmadd_pd(ai1, br, acc_im_k1);
a_k0_ptr = a_k0_ptr.sub(8);
a_k1_ptr = a_k1_ptr.sub(8);
b_ptr = b_ptr.add(8);
}
// Region 3: contributions to k1 only, j ∈ [j0_max, j1_max)
for _ in 0..j1_max - j0_max {
let ar1: __m256d = _mm256_loadu_pd(a_k1_ptr);
let ai1: __m256d = _mm256_loadu_pd(a_k1_ptr.add(4));
let br: __m256d = _mm256_loadu_pd(b_ptr);
let bi: __m256d = _mm256_loadu_pd(b_ptr.add(4));
acc_re_k1 = _mm256_fmadd_pd(ar1, br, acc_re_k1);
acc_re_k1 = _mm256_fnmadd_pd(ai1, bi, acc_re_k1);
acc_im_k1 = _mm256_fmadd_pd(ar1, bi, acc_im_k1);
acc_im_k1 = _mm256_fmadd_pd(ai1, br, acc_im_k1);
a_k1_ptr = a_k1_ptr.sub(8);
b_ptr = b_ptr.add(8);
}
}
// Store both coefficients
let dst_ptr = dst.as_mut_ptr();
_mm256_storeu_pd(dst_ptr, acc_re_k0);
_mm256_storeu_pd(dst_ptr.add(4), acc_im_k0);
_mm256_storeu_pd(dst_ptr.add(8), acc_re_k1);
_mm256_storeu_pd(dst_ptr.add(12), acc_im_k1);
}
}
/// # Safety
/// Caller must ensure the CPU supports AVX2 and FMA (e.g. `is_x86_feature_detected!("avx2")`).
#[target_feature(enable = "avx2", enable = "fma")]
pub unsafe fn reim4_convolution_by_real_const_1coeff_avx(k: usize, dst: &mut [f64; 8], a: &[f64], a_size: usize, b: &[f64]) {
use core::arch::x86_64::{__m256d, _mm256_fmadd_pd, _mm256_loadu_pd, _mm256_set1_pd, _mm256_setzero_pd, _mm256_storeu_pd};
unsafe {
let b_size: usize = b.len();
if k >= a_size + b_size {
let zero: __m256d = _mm256_setzero_pd();
let dst_ptr: *mut f64 = dst.as_mut_ptr();
_mm256_storeu_pd(dst_ptr, zero);
_mm256_storeu_pd(dst_ptr.add(4), zero);
return;
}
let j_min: usize = k.saturating_sub(a_size - 1);
let j_max: usize = (k + 1).min(b_size);
// acc_re = dst[0..4], acc_im = dst[4..8]
let mut acc_re: __m256d = _mm256_setzero_pd();
let mut acc_im: __m256d = _mm256_setzero_pd();
let mut a_ptr: *const f64 = a.as_ptr().add(8 * (k - j_min));
let mut b_ptr: *const f64 = b.as_ptr().add(j_min);
for _ in 0..j_max - j_min {
// Load a[(k - j)]
let ar: __m256d = _mm256_loadu_pd(a_ptr);
let ai: __m256d = _mm256_loadu_pd(a_ptr.add(4));
// Load scalar b[j] and broadcast
let br: __m256d = _mm256_set1_pd(*b_ptr);
// Complex * real:
// re += ar * br
// im += ai * br
acc_re = _mm256_fmadd_pd(ar, br, acc_re);
acc_im = _mm256_fmadd_pd(ai, br, acc_im);
a_ptr = a_ptr.sub(8);
b_ptr = b_ptr.add(1);
}
// Store accumulators into dst
_mm256_storeu_pd(dst.as_mut_ptr(), acc_re);
_mm256_storeu_pd(dst.as_mut_ptr().add(4), acc_im);
}
}
/// # Safety
/// Caller must ensure the CPU supports AVX2 and FMA (e.g. `is_x86_feature_detected!("avx2")`).
#[target_feature(enable = "avx2", enable = "fma")]
pub unsafe fn reim4_convolution_by_real_const_2coeffs_avx(k: usize, dst: &mut [f64; 16], a: &[f64], a_size: usize, b: &[f64]) {
use core::arch::x86_64::{__m256d, _mm256_fmadd_pd, _mm256_loadu_pd, _mm256_set1_pd, _mm256_setzero_pd, _mm256_storeu_pd};
let b_size: usize = b.len();
debug_assert!(a.len() >= 8 * a_size);
let k0: usize = k;
let k1: usize = k + 1;
let bound: usize = a_size + b_size;
// Since k1 = k0 + 1, k0 being out of range implies k1 is as well, so both
// output coefficients are zeroed.
if k0 >= bound {
unsafe {
let zero: __m256d = _mm256_setzero_pd();
let dst_ptr: *mut f64 = dst.as_mut_ptr();
_mm256_storeu_pd(dst_ptr, zero);
_mm256_storeu_pd(dst_ptr.add(4), zero);
_mm256_storeu_pd(dst_ptr.add(8), zero);
_mm256_storeu_pd(dst_ptr.add(12), zero);
}
return;
}
unsafe {
let mut acc_re_k0: __m256d = _mm256_setzero_pd();
let mut acc_im_k0: __m256d = _mm256_setzero_pd();
let mut acc_re_k1: __m256d = _mm256_setzero_pd();
let mut acc_im_k1: __m256d = _mm256_setzero_pd();
let j0_min: usize = (k0 + 1).saturating_sub(a_size);
let j0_max: usize = (k0 + 1).min(b_size);
if k1 >= bound {
let mut a_k0_ptr: *const f64 = a.as_ptr().add(8 * (k0 - j0_min));
let mut b_ptr: *const f64 = b.as_ptr().add(j0_min);
// Contributions to k0 only
for _ in 0..j0_max - j0_min {
let ar: __m256d = _mm256_loadu_pd(a_k0_ptr);
let ai: __m256d = _mm256_loadu_pd(a_k0_ptr.add(4));
let br: __m256d = _mm256_set1_pd(*b_ptr);
// complex * real
acc_re_k0 = _mm256_fmadd_pd(ar, br, acc_re_k0);
acc_im_k0 = _mm256_fmadd_pd(ai, br, acc_im_k0);
a_k0_ptr = a_k0_ptr.sub(8);
b_ptr = b_ptr.add(1);
}
} else {
let j1_min: usize = (k1 + 1).saturating_sub(a_size);
let j1_max: usize = (k1 + 1).min(b_size);
let mut a_k0_ptr: *const f64 = a.as_ptr().add(8 * (k0 - j0_min));
let mut a_k1_ptr: *const f64 = a.as_ptr().add(8 * (k1 - j1_min));
let mut b_ptr: *const f64 = b.as_ptr().add(j0_min);
// Region 1: k0 only, j ∈ [j0_min, j1_min)
for _ in 0..j1_min - j0_min {
let ar0: __m256d = _mm256_loadu_pd(a_k0_ptr);
let ai0: __m256d = _mm256_loadu_pd(a_k0_ptr.add(4));
let br: __m256d = _mm256_set1_pd(*b_ptr);
acc_re_k0 = _mm256_fmadd_pd(ar0, br, acc_re_k0);
acc_im_k0 = _mm256_fmadd_pd(ai0, br, acc_im_k0);
a_k0_ptr = a_k0_ptr.sub(8);
b_ptr = b_ptr.add(1);
}
// Region 2: overlap, contributions to both k0 and k1, j ∈ [j1_min, j0_max)
// Still “save one load on b”: we broadcast once and reuse.
for _ in 0..j0_max - j1_min {
let ar0: __m256d = _mm256_loadu_pd(a_k0_ptr);
let ai0: __m256d = _mm256_loadu_pd(a_k0_ptr.add(4));
let ar1: __m256d = _mm256_loadu_pd(a_k1_ptr);
let ai1: __m256d = _mm256_loadu_pd(a_k1_ptr.add(4));
let br: __m256d = _mm256_set1_pd(*b_ptr);
// k0
acc_re_k0 = _mm256_fmadd_pd(ar0, br, acc_re_k0);
acc_im_k0 = _mm256_fmadd_pd(ai0, br, acc_im_k0);
// k1
acc_re_k1 = _mm256_fmadd_pd(ar1, br, acc_re_k1);
acc_im_k1 = _mm256_fmadd_pd(ai1, br, acc_im_k1);
a_k0_ptr = a_k0_ptr.sub(8);
a_k1_ptr = a_k1_ptr.sub(8);
b_ptr = b_ptr.add(1);
}
// Region 3: k1 only, j ∈ [j0_max, j1_max)
for _ in 0..j1_max - j0_max {
let ar1: __m256d = _mm256_loadu_pd(a_k1_ptr);
let ai1: __m256d = _mm256_loadu_pd(a_k1_ptr.add(4));
let br: __m256d = _mm256_set1_pd(*b_ptr);
acc_re_k1 = _mm256_fmadd_pd(ar1, br, acc_re_k1);
acc_im_k1 = _mm256_fmadd_pd(ai1, br, acc_im_k1);
a_k1_ptr = a_k1_ptr.sub(8);
b_ptr = b_ptr.add(1);
}
}
// Store both coefficients
let dst_ptr = dst.as_mut_ptr();
_mm256_storeu_pd(dst_ptr, acc_re_k0);
_mm256_storeu_pd(dst_ptr.add(4), acc_im_k0);
_mm256_storeu_pd(dst_ptr.add(8), acc_re_k1);
_mm256_storeu_pd(dst_ptr.add(12), acc_im_k1);
}
}

View File

@@ -1,4 +1,8 @@
use poulpy_hal::{api::ModuleNew, layouts::Module, test_suite::convolution::test_bivariate_tensoring};
use poulpy_hal::{
api::ModuleNew,
layouts::Module,
test_suite::convolution::{test_convolution, test_convolution_by_const, test_convolution_pairwise},
};
use crate::FFT64Avx;
@@ -119,7 +123,19 @@ mod poulpy_cpu_avx {
}
#[test]
fn test_convolution_fft64_avx() {
let module: Module<FFT64Avx> = Module::<FFT64Avx>::new(64);
test_bivariate_tensoring(&module);
fn test_convolution_by_const_fft64_avx() {
let module: Module<FFT64Avx> = Module::<FFT64Avx>::new(8);
test_convolution_by_const(&module);
}
#[test]
fn test_convolution_fft64_avx() {
let module: Module<FFT64Avx> = Module::<FFT64Avx>::new(8);
test_convolution(&module);
}
#[test]
fn test_convolution_pairwise_fft64_avx() {
let module: Module<FFT64Avx> = Module::<FFT64Avx>::new(8);
test_convolution_pairwise(&module);
}

View File

@@ -53,11 +53,12 @@ where
{
fn vec_znx_normalize_impl<R, A>(
module: &Module<Self>,
res_base2k: usize,
res: &mut R,
res_base2k: usize,
res_offset: i64,
res_col: usize,
a_base2k: usize,
a: &A,
a_base2k: usize,
a_col: usize,
scratch: &mut Scratch<Self>,
) where
@@ -65,7 +66,7 @@ where
A: VecZnxToRef,
{
let (carry, _) = scratch.take_slice(module.vec_znx_normalize_tmp_bytes() / size_of::<i64>());
vec_znx_normalize::<R, A, Self>(res_base2k, res, res_col, a_base2k, a, a_col, carry);
vec_znx_normalize::<R, A, Self>(res, res_base2k, res_offset, res_col, a, a_base2k, a_col, carry);
}
}

View File

@@ -26,7 +26,7 @@ use poulpy_hal::{
source::Source,
};
unsafe impl VecZnxBigAllocBytesImpl<Self> for FFT64Avx {
unsafe impl VecZnxBigAllocBytesImpl for FFT64Avx {
fn vec_znx_big_bytes_of_impl(n: usize, cols: usize, size: usize) -> usize {
Self::layout_big_word_count() * n * cols * size * size_of::<f64>()
}
@@ -280,11 +280,12 @@ where
{
fn vec_znx_big_normalize_impl<R, A>(
module: &Module<Self>,
res_basek: usize,
res: &mut R,
res_base2k: usize,
res_offset: i64,
res_col: usize,
a_basek: usize,
a: &A,
a_base2k: usize,
a_col: usize,
scratch: &mut Scratch<Self>,
) where
@@ -292,7 +293,7 @@ where
A: VecZnxBigToRef<Self>,
{
let (carry, _) = scratch.take_slice(module.vec_znx_big_normalize_tmp_bytes() / size_of::<i64>());
vec_znx_big_normalize(res_basek, res, res_col, a_basek, a, a_col, carry);
vec_znx_big_normalize(res, res_base2k, res_offset, res_col, a, a_base2k, a_col, carry);
}
}

View File

@@ -53,12 +53,8 @@ pub fn znx_automorphism_avx(p: i64, res: &mut [i64], a: &[i64]) {
let mask_1n_vec: __m256i = _mm256_set1_epi64x(mask_1n as i64);
// Lane offsets [0, inv, 2*inv, 3*inv] (mod 2n)
let lane_offsets: __m256i = _mm256_set_epi64x(
((inv * 3) & mask_2n) as i64,
((inv * 2) & mask_2n) as i64,
inv as i64,
0i64,
);
let lane_offsets: __m256i =
_mm256_set_epi64x(((inv * 3) & mask_2n) as i64, ((inv * 2) & mask_2n) as i64, inv as i64, 0i64);
// t_base = (j * inv) mod 2n.
let mut t_base: usize = 0;

View File

@@ -82,14 +82,14 @@ pub fn znx_extract_digit_addmul_avx(base2k: usize, lsh: usize, res: &mut [i64],
let mut ss: *mut __m256i = src.as_mut_ptr() as *mut __m256i;
// constants for digit/carry extraction
let (mask, sign, basek_vec, top_mask) = normalize_consts_avx(base2k);
let (mask, sign, base2k_vec, top_mask) = normalize_consts_avx(base2k);
let lsh_v: __m256i = _mm256_set1_epi64x(lsh as i64);
for _ in 0..span {
// load source & extract digit/carry
let sv: __m256i = _mm256_loadu_si256(ss);
let digit_256: __m256i = get_digit_avx(sv, mask, sign);
let carry_256: __m256i = get_carry_avx(sv, digit_256, basek_vec, top_mask);
let carry_256: __m256i = get_carry_avx(sv, digit_256, base2k_vec, top_mask);
// res += (digit << lsh)
let rv: __m256i = _mm256_loadu_si256(rr);
@@ -135,7 +135,7 @@ pub fn znx_normalize_digit_avx(base2k: usize, res: &mut [i64], src: &mut [i64])
let mut ss: *mut __m256i = src.as_mut_ptr() as *mut __m256i;
// Constants for digit/carry extraction
let (mask, sign, basek_vec, top_mask) = normalize_consts_avx(base2k);
let (mask, sign, base2k_vec, top_mask) = normalize_consts_avx(base2k);
for _ in 0..span {
// Load res lane
@@ -143,7 +143,7 @@ pub fn znx_normalize_digit_avx(base2k: usize, res: &mut [i64], src: &mut [i64])
// Extract digit and carry from res
let digit_256: __m256i = get_digit_avx(rv, mask, sign);
let carry_256: __m256i = get_carry_avx(rv, digit_256, basek_vec, top_mask);
let carry_256: __m256i = get_carry_avx(rv, digit_256, base2k_vec, top_mask);
// src += carry
let sv: __m256i = _mm256_loadu_si256(ss);
@@ -187,7 +187,7 @@ pub fn znx_normalize_first_step_carry_only_avx(base2k: usize, lsh: usize, x: &[i
let mut xx: *const __m256i = x.as_ptr() as *const __m256i;
let mut cc: *mut __m256i = carry.as_ptr() as *mut __m256i;
let (mask, sign, basek_vec, top_mask) = if lsh == 0 {
let (mask, sign, base2k_vec, top_mask) = if lsh == 0 {
normalize_consts_avx(base2k)
} else {
normalize_consts_avx(base2k - lsh)
@@ -200,7 +200,7 @@ pub fn znx_normalize_first_step_carry_only_avx(base2k: usize, lsh: usize, x: &[i
let digit_256: __m256i = get_digit_avx(xv, mask, sign);
// (x - digit) >> base2k
let carry_256: __m256i = get_carry_avx(xv, digit_256, basek_vec, top_mask);
let carry_256: __m256i = get_carry_avx(xv, digit_256, base2k_vec, top_mask);
_mm256_storeu_si256(cc, carry_256);
@@ -239,7 +239,7 @@ pub fn znx_normalize_first_step_inplace_avx(base2k: usize, lsh: usize, x: &mut [
let mut cc: *mut __m256i = carry.as_ptr() as *mut __m256i;
if lsh == 0 {
let (mask, sign, basek_vec, top_mask) = normalize_consts_avx(base2k);
let (mask, sign, base2k_vec, top_mask) = normalize_consts_avx(base2k);
for _ in 0..span {
let xv: __m256i = _mm256_loadu_si256(xx);
@@ -248,7 +248,7 @@ pub fn znx_normalize_first_step_inplace_avx(base2k: usize, lsh: usize, x: &mut [
let digit_256: __m256i = get_digit_avx(xv, mask, sign);
// (x - digit) >> base2k
let carry_256: __m256i = get_carry_avx(xv, digit_256, basek_vec, top_mask);
let carry_256: __m256i = get_carry_avx(xv, digit_256, base2k_vec, top_mask);
_mm256_storeu_si256(xx, digit_256);
_mm256_storeu_si256(cc, carry_256);
@@ -257,7 +257,7 @@ pub fn znx_normalize_first_step_inplace_avx(base2k: usize, lsh: usize, x: &mut [
cc = cc.add(1);
}
} else {
let (mask, sign, basek_vec, top_mask) = normalize_consts_avx(base2k - lsh);
let (mask, sign, base2k_vec, top_mask) = normalize_consts_avx(base2k - lsh);
let lsh_v: __m256i = _mm256_set1_epi64x(lsh as i64);
@@ -268,7 +268,7 @@ pub fn znx_normalize_first_step_inplace_avx(base2k: usize, lsh: usize, x: &mut [
let digit_256: __m256i = get_digit_avx(xv, mask, sign);
// (x - digit) >> base2k
let carry_256: __m256i = get_carry_avx(xv, digit_256, basek_vec, top_mask);
let carry_256: __m256i = get_carry_avx(xv, digit_256, base2k_vec, top_mask);
_mm256_storeu_si256(xx, _mm256_sllv_epi64(digit_256, lsh_v));
_mm256_storeu_si256(cc, carry_256);
@@ -311,7 +311,7 @@ pub fn znx_normalize_first_step_avx(base2k: usize, lsh: usize, x: &mut [i64], a:
let mut cc: *mut __m256i = carry.as_ptr() as *mut __m256i;
if lsh == 0 {
let (mask, sign, basek_vec, top_mask) = normalize_consts_avx(base2k);
let (mask, sign, base2k_vec, top_mask) = normalize_consts_avx(base2k);
for _ in 0..span {
let av: __m256i = _mm256_loadu_si256(aa);
@@ -320,7 +320,7 @@ pub fn znx_normalize_first_step_avx(base2k: usize, lsh: usize, x: &mut [i64], a:
let digit_256: __m256i = get_digit_avx(av, mask, sign);
// (x - digit) >> base2k
let carry_256: __m256i = get_carry_avx(av, digit_256, basek_vec, top_mask);
let carry_256: __m256i = get_carry_avx(av, digit_256, base2k_vec, top_mask);
_mm256_storeu_si256(xx, digit_256);
_mm256_storeu_si256(cc, carry_256);
@@ -332,7 +332,7 @@ pub fn znx_normalize_first_step_avx(base2k: usize, lsh: usize, x: &mut [i64], a:
} else {
use std::arch::x86_64::_mm256_set1_epi64x;
let (mask, sign, basek_vec, top_mask) = normalize_consts_avx(base2k - lsh);
let (mask, sign, base2k_vec, top_mask) = normalize_consts_avx(base2k - lsh);
let lsh_v: __m256i = _mm256_set1_epi64x(lsh as i64);
@@ -343,7 +343,7 @@ pub fn znx_normalize_first_step_avx(base2k: usize, lsh: usize, x: &mut [i64], a:
let digit_256: __m256i = get_digit_avx(av, mask, sign);
// (x - digit) >> base2k
let carry_256: __m256i = get_carry_avx(av, digit_256, basek_vec, top_mask);
let carry_256: __m256i = get_carry_avx(av, digit_256, base2k_vec, top_mask);
_mm256_storeu_si256(xx, _mm256_sllv_epi64(digit_256, lsh_v));
_mm256_storeu_si256(cc, carry_256);
@@ -359,13 +359,7 @@ pub fn znx_normalize_first_step_avx(base2k: usize, lsh: usize, x: &mut [i64], a:
if !x.len().is_multiple_of(4) {
use poulpy_hal::reference::znx::znx_normalize_first_step_ref;
znx_normalize_first_step_ref(
base2k,
lsh,
&mut x[span << 2..],
&a[span << 2..],
&mut carry[span << 2..],
);
znx_normalize_first_step_ref(base2k, lsh, &mut x[span << 2..], &a[span << 2..], &mut carry[span << 2..]);
}
}
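// Illustrative scalar sketch of the first-step behaviour for a single lane,
// mirroring the two AVX paths above: with lsh == 0 the limb is split at
// `base2k`; with lsh > 0 it is split at the reduced base `base2k - lsh` and
// the digit is re-aligned by `lsh` (the `_mm256_sllv_epi64` above). Reuses
// `digit_and_carry` from the earlier sketch; names are assumptions, not the
// crate's API.
fn first_step_scalar(base2k: usize, lsh: usize, x: i64) -> (i64, i64) {
    if lsh == 0 {
        digit_and_carry(base2k, x)
    } else {
        let (digit, carry) = digit_and_carry(base2k - lsh, x);
        (digit << lsh, carry)
    }
}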
@@ -386,7 +380,7 @@ pub fn znx_normalize_middle_step_inplace_avx(base2k: usize, lsh: usize, x: &mut
let span: usize = n >> 2;
let (mask, sign, basek_vec, top_mask) = normalize_consts_avx(base2k);
let (mask, sign, base2k_vec, top_mask) = normalize_consts_avx(base2k);
unsafe {
let mut xx: *mut __m256i = x.as_mut_ptr() as *mut __m256i;
@@ -398,11 +392,11 @@ pub fn znx_normalize_middle_step_inplace_avx(base2k: usize, lsh: usize, x: &mut
let cv: __m256i = _mm256_loadu_si256(cc);
let d0: __m256i = get_digit_avx(xv, mask, sign);
let c0: __m256i = get_carry_avx(xv, d0, basek_vec, top_mask);
let c0: __m256i = get_carry_avx(xv, d0, base2k_vec, top_mask);
let s: __m256i = _mm256_add_epi64(d0, cv);
let x1: __m256i = get_digit_avx(s, mask, sign);
let c1: __m256i = get_carry_avx(s, x1, basek_vec, top_mask);
let c1: __m256i = get_carry_avx(s, x1, base2k_vec, top_mask);
let cout: __m256i = _mm256_add_epi64(c0, c1);
_mm256_storeu_si256(xx, x1);
@@ -414,7 +408,7 @@ pub fn znx_normalize_middle_step_inplace_avx(base2k: usize, lsh: usize, x: &mut
} else {
use std::arch::x86_64::_mm256_set1_epi64x;
let (mask_lsh, sign_lsh, basek_vec_lsh, top_mask_lsh) = normalize_consts_avx(base2k - lsh);
let (mask_lsh, sign_lsh, base2k_vec_lsh, top_mask_lsh) = normalize_consts_avx(base2k - lsh);
let lsh_v: __m256i = _mm256_set1_epi64x(lsh as i64);
@@ -423,13 +417,13 @@ pub fn znx_normalize_middle_step_inplace_avx(base2k: usize, lsh: usize, x: &mut
let cv: __m256i = _mm256_loadu_si256(cc);
let d0: __m256i = get_digit_avx(xv, mask_lsh, sign_lsh);
let c0: __m256i = get_carry_avx(xv, d0, basek_vec_lsh, top_mask_lsh);
let c0: __m256i = get_carry_avx(xv, d0, base2k_vec_lsh, top_mask_lsh);
let d0_lsh: __m256i = _mm256_sllv_epi64(d0, lsh_v);
let s: __m256i = _mm256_add_epi64(d0_lsh, cv);
let x1: __m256i = get_digit_avx(s, mask, sign);
let c1: __m256i = get_carry_avx(s, x1, basek_vec, top_mask);
let c1: __m256i = get_carry_avx(s, x1, base2k_vec, top_mask);
let cout: __m256i = _mm256_add_epi64(c0, c1);
_mm256_storeu_si256(xx, x1);
@@ -465,7 +459,7 @@ pub fn znx_normalize_middle_step_carry_only_avx(base2k: usize, lsh: usize, x: &[
let span: usize = n >> 2;
let (mask, sign, basek_vec, top_mask) = normalize_consts_avx(base2k);
let (mask, sign, base2k_vec, top_mask) = normalize_consts_avx(base2k);
unsafe {
let mut xx: *const __m256i = x.as_ptr() as *const __m256i;
@@ -477,11 +471,11 @@ pub fn znx_normalize_middle_step_carry_only_avx(base2k: usize, lsh: usize, x: &[
let cv: __m256i = _mm256_loadu_si256(cc);
let d0: __m256i = get_digit_avx(xv, mask, sign);
let c0: __m256i = get_carry_avx(xv, d0, basek_vec, top_mask);
let c0: __m256i = get_carry_avx(xv, d0, base2k_vec, top_mask);
let s: __m256i = _mm256_add_epi64(d0, cv);
let x1: __m256i = get_digit_avx(s, mask, sign);
let c1: __m256i = get_carry_avx(s, x1, basek_vec, top_mask);
let c1: __m256i = get_carry_avx(s, x1, base2k_vec, top_mask);
let cout: __m256i = _mm256_add_epi64(c0, c1);
_mm256_storeu_si256(cc, cout);
@@ -492,7 +486,7 @@ pub fn znx_normalize_middle_step_carry_only_avx(base2k: usize, lsh: usize, x: &[
} else {
use std::arch::x86_64::_mm256_set1_epi64x;
let (mask_lsh, sign_lsh, basek_vec_lsh, top_mask_lsh) = normalize_consts_avx(base2k - lsh);
let (mask_lsh, sign_lsh, base2k_vec_lsh, top_mask_lsh) = normalize_consts_avx(base2k - lsh);
let lsh_v: __m256i = _mm256_set1_epi64x(lsh as i64);
@@ -501,13 +495,13 @@ pub fn znx_normalize_middle_step_carry_only_avx(base2k: usize, lsh: usize, x: &[
let cv: __m256i = _mm256_loadu_si256(cc);
let d0: __m256i = get_digit_avx(xv, mask_lsh, sign_lsh);
let c0: __m256i = get_carry_avx(xv, d0, basek_vec_lsh, top_mask_lsh);
let c0: __m256i = get_carry_avx(xv, d0, base2k_vec_lsh, top_mask_lsh);
let d0_lsh: __m256i = _mm256_sllv_epi64(d0, lsh_v);
let s: __m256i = _mm256_add_epi64(d0_lsh, cv);
let x1: __m256i = get_digit_avx(s, mask, sign);
let c1: __m256i = get_carry_avx(s, x1, basek_vec, top_mask);
let c1: __m256i = get_carry_avx(s, x1, base2k_vec, top_mask);
let cout: __m256i = _mm256_add_epi64(c0, c1);
_mm256_storeu_si256(cc, cout);
@@ -543,7 +537,7 @@ pub fn znx_normalize_middle_step_avx(base2k: usize, lsh: usize, x: &mut [i64], a
let span: usize = n >> 2;
let (mask, sign, basek_vec, top_mask) = normalize_consts_avx(base2k);
let (mask, sign, base2k_vec, top_mask) = normalize_consts_avx(base2k);
unsafe {
let mut xx: *mut __m256i = x.as_mut_ptr() as *mut __m256i;
@@ -556,11 +550,11 @@ pub fn znx_normalize_middle_step_avx(base2k: usize, lsh: usize, x: &mut [i64], a
let cv: __m256i = _mm256_loadu_si256(cc);
let d0: __m256i = get_digit_avx(av, mask, sign);
let c0: __m256i = get_carry_avx(av, d0, basek_vec, top_mask);
let c0: __m256i = get_carry_avx(av, d0, base2k_vec, top_mask);
let s: __m256i = _mm256_add_epi64(d0, cv);
let x1: __m256i = get_digit_avx(s, mask, sign);
let c1: __m256i = get_carry_avx(s, x1, basek_vec, top_mask);
let c1: __m256i = get_carry_avx(s, x1, base2k_vec, top_mask);
let cout: __m256i = _mm256_add_epi64(c0, c1);
_mm256_storeu_si256(xx, x1);
@@ -573,7 +567,7 @@ pub fn znx_normalize_middle_step_avx(base2k: usize, lsh: usize, x: &mut [i64], a
} else {
use std::arch::x86_64::_mm256_set1_epi64x;
let (mask_lsh, sign_lsh, basek_vec_lsh, top_mask_lsh) = normalize_consts_avx(base2k - lsh);
let (mask_lsh, sign_lsh, base2k_vec_lsh, top_mask_lsh) = normalize_consts_avx(base2k - lsh);
let lsh_v: __m256i = _mm256_set1_epi64x(lsh as i64);
@@ -582,13 +576,13 @@ pub fn znx_normalize_middle_step_avx(base2k: usize, lsh: usize, x: &mut [i64], a
let cv: __m256i = _mm256_loadu_si256(cc);
let d0: __m256i = get_digit_avx(av, mask_lsh, sign_lsh);
let c0: __m256i = get_carry_avx(av, d0, basek_vec_lsh, top_mask_lsh);
let c0: __m256i = get_carry_avx(av, d0, base2k_vec_lsh, top_mask_lsh);
let d0_lsh: __m256i = _mm256_sllv_epi64(d0, lsh_v);
let s: __m256i = _mm256_add_epi64(d0_lsh, cv);
let x1: __m256i = get_digit_avx(s, mask, sign);
let c1: __m256i = get_carry_avx(s, x1, basek_vec, top_mask);
let c1: __m256i = get_carry_avx(s, x1, base2k_vec, top_mask);
let cout: __m256i = _mm256_add_epi64(c0, c1);
_mm256_storeu_si256(xx, x1);
@@ -604,13 +598,7 @@ pub fn znx_normalize_middle_step_avx(base2k: usize, lsh: usize, x: &mut [i64], a
if !x.len().is_multiple_of(4) {
use poulpy_hal::reference::znx::znx_normalize_middle_step_ref;
znx_normalize_middle_step_ref(
base2k,
lsh,
&mut x[span << 2..],
&a[span << 2..],
&mut carry[span << 2..],
);
znx_normalize_middle_step_ref(base2k, lsh, &mut x[span << 2..], &a[span << 2..], &mut carry[span << 2..]);
}
}
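// Illustrative scalar model of the middle-step recurrence in the lsh == 0
// loop above: split the limb, add the incoming carry, split the sum again,
// and forward the accumulated carry. Variable names track d0/c0/s/x1/c1/cout
// in the AVX code; `digit_and_carry` is the sketch defined earlier.
fn middle_step_scalar(base2k: usize, x: i64, carry_in: i64) -> (i64, i64) {
    let (d0, c0) = digit_and_carry(base2k, x); // first split of the raw limb
    let s = d0 + carry_in; // s = d0 + cv
    let (x1, c1) = digit_and_carry(base2k, s); // re-normalize the sum
    (x1, c0 + c1) // cout = c0 + c1
}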
@@ -753,13 +741,7 @@ pub fn znx_normalize_final_step_avx(base2k: usize, lsh: usize, x: &mut [i64], a:
if !x.len().is_multiple_of(4) {
use poulpy_hal::reference::znx::znx_normalize_final_step_ref;
znx_normalize_final_step_ref(
base2k,
lsh,
&mut x[span << 2..],
&a[span << 2..],
&mut carry[span << 2..],
);
znx_normalize_final_step_ref(base2k, lsh, &mut x[span << 2..], &a[span << 2..], &mut carry[span << 2..]);
}
}
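// Illustrative sketch of the tail convention shared by the AVX entry points
// above: run the vectorized body on `span << 2` elements (span = n >> 2 full
// __m256i blocks) and hand the remaining n % 4 elements to the scalar
// reference routine on the `[span << 2..]` suffix. The two closures stand in
// for the intrinsic loop and the `*_ref` fallback and are assumptions.
fn with_scalar_tail(x: &mut [i64], avx_body: impl Fn(&mut [i64]), scalar_ref: impl Fn(&mut [i64])) {
    let span: usize = x.len() >> 2;
    avx_body(&mut x[..span << 2]); // vectorized part, length multiple of 4
    if !x.len().is_multiple_of(4) {
        scalar_ref(&mut x[span << 2..]); // leftover 1..=3 elements
    }
}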
@@ -832,8 +814,8 @@ mod tests {
unsafe {
let x_256: __m256i = _mm256_loadu_si256(x.as_ptr() as *const __m256i);
let d_256: __m256i = _mm256_loadu_si256(carry.as_ptr() as *const __m256i);
let (_, _, basek_vec, top_mask) = normalize_consts_avx(base2k);
let digit: __m256i = get_carry_avx(x_256, d_256, basek_vec, top_mask);
let (_, _, base2k_vec, top_mask) = normalize_consts_avx(base2k);
let digit: __m256i = get_carry_avx(x_256, d_256, base2k_vec, top_mask);
_mm256_storeu_si256(y1.as_mut_ptr() as *mut __m256i, digit);
}
assert_eq!(y0, y1);
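        // Hypothetical extra scalar check in the same spirit as the assertion
        // above: the centered digit/carry split (see the digit_and_carry
        // sketch earlier) must reconstruct its input exactly for any valid
        // base2k. Illustrative only, not part of the patch.
        for &x in &[0i64, 1, -1, 4095, -4096, 123_456_789, -987_654_321] {
            let sign: i64 = 1i64 << (base2k - 1);
            let mask: i64 = (1i64 << base2k) - 1;
            let digit: i64 = ((x & mask) ^ sign) - sign;
            let carry: i64 = (x - digit) >> base2k;
            assert_eq!(x, (carry << base2k) + digit);
        }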