Add bivariate convolution

This commit is contained in:
Jean-Philippe Bossuat
2025-10-23 19:00:26 +02:00
parent 9bb6256fc4
commit af1c98c2c4
18 changed files with 454 additions and 26 deletions

View File

@@ -1,4 +1,8 @@
use poulpy_hal::{backend_test_suite, cross_backend_test_suite};
use poulpy_hal::{
api::ModuleNew, backend_test_suite, cross_backend_test_suite, layouts::Module, test_suite::convolution::test_convolution,
};
use crate::FFT64Avx;
cross_backend_test_suite! {
mod vec_znx,
@@ -115,3 +119,9 @@ backend_test_suite! {
test_vec_znx_big_sub_small_b_inplace => poulpy_hal::reference::fft64::vec_znx_big::test_vec_znx_big_add_normal,
}
}
#[test]
fn test_convolution_fft64_avx() {
let module: Module<FFT64Avx> = Module::<FFT64Avx>::new(64);
test_convolution(&module);
}

View File

@@ -4,14 +4,15 @@ use poulpy_hal::{
VecZnxToRef,
},
oep::{
VecZnxDftAddImpl, VecZnxDftAddInplaceImpl, VecZnxDftAllocBytesImpl, VecZnxDftAllocImpl, VecZnxDftApplyImpl,
VecZnxDftCopyImpl, VecZnxDftFromBytesImpl, VecZnxDftSubImpl, VecZnxDftSubInplaceImpl, VecZnxDftSubNegateInplaceImpl,
VecZnxDftZeroImpl, VecZnxIdftApplyConsumeImpl, VecZnxIdftApplyImpl, VecZnxIdftApplyTmpAImpl, VecZnxIdftApplyTmpBytesImpl,
VecZnxDftAddImpl, VecZnxDftAddInplaceImpl, VecZnxDftAddScaledInplaceImpl, VecZnxDftAllocBytesImpl, VecZnxDftAllocImpl,
VecZnxDftApplyImpl, VecZnxDftCopyImpl, VecZnxDftFromBytesImpl, VecZnxDftSubImpl, VecZnxDftSubInplaceImpl,
VecZnxDftSubNegateInplaceImpl, VecZnxDftZeroImpl, VecZnxIdftApplyConsumeImpl, VecZnxIdftApplyImpl,
VecZnxIdftApplyTmpAImpl, VecZnxIdftApplyTmpBytesImpl,
},
reference::fft64::vec_znx_dft::{
vec_znx_dft_add, vec_znx_dft_add_inplace, vec_znx_dft_apply, vec_znx_dft_copy, vec_znx_dft_sub, vec_znx_dft_sub_inplace,
vec_znx_dft_sub_negate_inplace, vec_znx_dft_zero, vec_znx_idft_apply, vec_znx_idft_apply_consume,
vec_znx_idft_apply_tmpa,
vec_znx_dft_add, vec_znx_dft_add_inplace, vec_znx_dft_add_scaled_inplace, vec_znx_dft_apply, vec_znx_dft_copy,
vec_znx_dft_sub, vec_znx_dft_sub_inplace, vec_znx_dft_sub_negate_inplace, vec_znx_dft_zero, vec_znx_idft_apply,
vec_znx_idft_apply_consume, vec_znx_idft_apply_tmpa,
},
};
@@ -121,6 +122,22 @@ unsafe impl VecZnxDftAddInplaceImpl<Self> for FFT64Avx {
}
}
unsafe impl VecZnxDftAddScaledInplaceImpl<Self> for FFT64Avx {
fn vec_znx_dft_add_scaled_inplace_impl<R, A>(
_module: &Module<Self>,
res: &mut R,
res_col: usize,
a: &A,
a_col: usize,
a_scale: i64,
) where
R: VecZnxDftToMut<Self>,
A: VecZnxDftToRef<Self>,
{
vec_znx_dft_add_scaled_inplace(res, res_col, a, a_col, a_scale);
}
}
unsafe impl VecZnxDftSubImpl<Self> for FFT64Avx {
fn vec_znx_dft_sub_impl<R, A, B>(
_module: &Module<Self>,

View File

@@ -9,4 +9,7 @@ mod vmp;
mod zn;
mod znx;
#[cfg(test)]
mod tests;
pub struct FFT64Ref {}

View File

@@ -0,0 +1,9 @@
use poulpy_hal::{api::ModuleNew, layouts::Module, test_suite::convolution::test_convolution};
use crate::FFT64Ref;
#[test]
fn test_convolution_fft64_ref() {
let module: Module<FFT64Ref> = Module::<FFT64Ref>::new(64);
test_convolution(&module);
}

View File

@@ -4,14 +4,15 @@ use poulpy_hal::{
VecZnxToRef,
},
oep::{
VecZnxDftAddImpl, VecZnxDftAddInplaceImpl, VecZnxDftAllocBytesImpl, VecZnxDftAllocImpl, VecZnxDftApplyImpl,
VecZnxDftCopyImpl, VecZnxDftFromBytesImpl, VecZnxDftSubImpl, VecZnxDftSubInplaceImpl, VecZnxDftSubNegateInplaceImpl,
VecZnxDftZeroImpl, VecZnxIdftApplyConsumeImpl, VecZnxIdftApplyImpl, VecZnxIdftApplyTmpAImpl, VecZnxIdftApplyTmpBytesImpl,
VecZnxDftAddImpl, VecZnxDftAddInplaceImpl, VecZnxDftAddScaledInplaceImpl, VecZnxDftAllocBytesImpl, VecZnxDftAllocImpl,
VecZnxDftApplyImpl, VecZnxDftCopyImpl, VecZnxDftFromBytesImpl, VecZnxDftSubImpl, VecZnxDftSubInplaceImpl,
VecZnxDftSubNegateInplaceImpl, VecZnxDftZeroImpl, VecZnxIdftApplyConsumeImpl, VecZnxIdftApplyImpl,
VecZnxIdftApplyTmpAImpl, VecZnxIdftApplyTmpBytesImpl,
},
reference::fft64::vec_znx_dft::{
vec_znx_dft_add, vec_znx_dft_add_inplace, vec_znx_dft_apply, vec_znx_dft_copy, vec_znx_dft_sub, vec_znx_dft_sub_inplace,
vec_znx_dft_sub_negate_inplace, vec_znx_dft_zero, vec_znx_idft_apply, vec_znx_idft_apply_consume,
vec_znx_idft_apply_tmpa,
vec_znx_dft_add, vec_znx_dft_add_inplace, vec_znx_dft_add_scaled_inplace, vec_znx_dft_apply, vec_znx_dft_copy,
vec_znx_dft_sub, vec_znx_dft_sub_inplace, vec_znx_dft_sub_negate_inplace, vec_znx_dft_zero, vec_znx_idft_apply,
vec_znx_idft_apply_consume, vec_znx_idft_apply_tmpa,
},
};
@@ -111,6 +112,22 @@ unsafe impl VecZnxDftAddImpl<Self> for FFT64Ref {
}
}
unsafe impl VecZnxDftAddScaledInplaceImpl<Self> for FFT64Ref {
fn vec_znx_dft_add_scaled_inplace_impl<R, A>(
_module: &Module<Self>,
res: &mut R,
res_col: usize,
a: &A,
a_col: usize,
a_scale: i64,
) where
R: VecZnxDftToMut<Self>,
A: VecZnxDftToRef<Self>,
{
vec_znx_dft_add_scaled_inplace(res, res_col, a, a_col, a_scale);
}
}
unsafe impl VecZnxDftAddInplaceImpl<Self> for FFT64Ref {
fn vec_znx_dft_add_inplace_impl<R, A>(_module: &Module<Self>, res: &mut R, res_col: usize, a: &A, a_col: usize)
where