From af1c98c2c42451eec1a390cd99e086476356febd Mon Sep 17 00:00:00 2001 From: Jean-Philippe Bossuat Date: Thu, 23 Oct 2025 19:00:26 +0200 Subject: [PATCH] Add bivariate convolution --- poulpy-backend/src/cpu_fft64_avx/tests.rs | 12 +- .../src/cpu_fft64_avx/vec_znx_dft.rs | 29 +++- poulpy-backend/src/cpu_fft64_ref/mod.rs | 3 + poulpy-backend/src/cpu_fft64_ref/tests.rs | 9 + .../src/cpu_fft64_ref/vec_znx_dft.rs | 29 +++- poulpy-hal/src/api/convolution.rs | 109 ++++++++++++ poulpy-hal/src/api/mod.rs | 2 + poulpy-hal/src/api/vec_znx_big.rs | 4 +- poulpy-hal/src/api/vec_znx_dft.rs | 7 + poulpy-hal/src/delegates/vec_znx_dft.rs | 26 ++- poulpy-hal/src/layouts/module.rs | 5 +- poulpy-hal/src/layouts/vec_znx.rs | 24 ++- poulpy-hal/src/layouts/znx_base.rs | 3 +- poulpy-hal/src/oep/vec_znx_dft.rs | 17 ++ poulpy-hal/src/reference/fft64/vec_znx_dft.rs | 38 ++++ .../src/reference/vec_znx/convolution.rs | 0 poulpy-hal/src/test_suite/convolution.rs | 162 ++++++++++++++++++ poulpy-hal/src/test_suite/mod.rs | 1 + 18 files changed, 454 insertions(+), 26 deletions(-) create mode 100644 poulpy-backend/src/cpu_fft64_ref/tests.rs create mode 100644 poulpy-hal/src/api/convolution.rs create mode 100644 poulpy-hal/src/reference/vec_znx/convolution.rs create mode 100644 poulpy-hal/src/test_suite/convolution.rs diff --git a/poulpy-backend/src/cpu_fft64_avx/tests.rs b/poulpy-backend/src/cpu_fft64_avx/tests.rs index d57f6c4..35ae6d3 100644 --- a/poulpy-backend/src/cpu_fft64_avx/tests.rs +++ b/poulpy-backend/src/cpu_fft64_avx/tests.rs @@ -1,4 +1,8 @@ -use poulpy_hal::{backend_test_suite, cross_backend_test_suite}; +use poulpy_hal::{ + api::ModuleNew, backend_test_suite, cross_backend_test_suite, layouts::Module, test_suite::convolution::test_convolution, +}; + +use crate::FFT64Avx; cross_backend_test_suite! { mod vec_znx, @@ -115,3 +119,9 @@ backend_test_suite! { test_vec_znx_big_sub_small_b_inplace => poulpy_hal::reference::fft64::vec_znx_big::test_vec_znx_big_add_normal, } } + +#[test] +fn test_convolution_fft64_avx() { + let module: Module = Module::::new(64); + test_convolution(&module); +} diff --git a/poulpy-backend/src/cpu_fft64_avx/vec_znx_dft.rs b/poulpy-backend/src/cpu_fft64_avx/vec_znx_dft.rs index 063ee26..57ffc6f 100644 --- a/poulpy-backend/src/cpu_fft64_avx/vec_znx_dft.rs +++ b/poulpy-backend/src/cpu_fft64_avx/vec_znx_dft.rs @@ -4,14 +4,15 @@ use poulpy_hal::{ VecZnxToRef, }, oep::{ - VecZnxDftAddImpl, VecZnxDftAddInplaceImpl, VecZnxDftAllocBytesImpl, VecZnxDftAllocImpl, VecZnxDftApplyImpl, - VecZnxDftCopyImpl, VecZnxDftFromBytesImpl, VecZnxDftSubImpl, VecZnxDftSubInplaceImpl, VecZnxDftSubNegateInplaceImpl, - VecZnxDftZeroImpl, VecZnxIdftApplyConsumeImpl, VecZnxIdftApplyImpl, VecZnxIdftApplyTmpAImpl, VecZnxIdftApplyTmpBytesImpl, + VecZnxDftAddImpl, VecZnxDftAddInplaceImpl, VecZnxDftAddScaledInplaceImpl, VecZnxDftAllocBytesImpl, VecZnxDftAllocImpl, + VecZnxDftApplyImpl, VecZnxDftCopyImpl, VecZnxDftFromBytesImpl, VecZnxDftSubImpl, VecZnxDftSubInplaceImpl, + VecZnxDftSubNegateInplaceImpl, VecZnxDftZeroImpl, VecZnxIdftApplyConsumeImpl, VecZnxIdftApplyImpl, + VecZnxIdftApplyTmpAImpl, VecZnxIdftApplyTmpBytesImpl, }, reference::fft64::vec_znx_dft::{ - vec_znx_dft_add, vec_znx_dft_add_inplace, vec_znx_dft_apply, vec_znx_dft_copy, vec_znx_dft_sub, vec_znx_dft_sub_inplace, - vec_znx_dft_sub_negate_inplace, vec_znx_dft_zero, vec_znx_idft_apply, vec_znx_idft_apply_consume, - vec_znx_idft_apply_tmpa, + vec_znx_dft_add, vec_znx_dft_add_inplace, vec_znx_dft_add_scaled_inplace, vec_znx_dft_apply, vec_znx_dft_copy, + vec_znx_dft_sub, vec_znx_dft_sub_inplace, vec_znx_dft_sub_negate_inplace, vec_znx_dft_zero, vec_znx_idft_apply, + vec_znx_idft_apply_consume, vec_znx_idft_apply_tmpa, }, }; @@ -121,6 +122,22 @@ unsafe impl VecZnxDftAddInplaceImpl for FFT64Avx { } } +unsafe impl VecZnxDftAddScaledInplaceImpl for FFT64Avx { + fn vec_znx_dft_add_scaled_inplace_impl( + _module: &Module, + res: &mut R, + res_col: usize, + a: &A, + a_col: usize, + a_scale: i64, + ) where + R: VecZnxDftToMut, + A: VecZnxDftToRef, + { + vec_znx_dft_add_scaled_inplace(res, res_col, a, a_col, a_scale); + } +} + unsafe impl VecZnxDftSubImpl for FFT64Avx { fn vec_znx_dft_sub_impl( _module: &Module, diff --git a/poulpy-backend/src/cpu_fft64_ref/mod.rs b/poulpy-backend/src/cpu_fft64_ref/mod.rs index 9f1be05..360c315 100644 --- a/poulpy-backend/src/cpu_fft64_ref/mod.rs +++ b/poulpy-backend/src/cpu_fft64_ref/mod.rs @@ -9,4 +9,7 @@ mod vmp; mod zn; mod znx; +#[cfg(test)] +mod tests; + pub struct FFT64Ref {} diff --git a/poulpy-backend/src/cpu_fft64_ref/tests.rs b/poulpy-backend/src/cpu_fft64_ref/tests.rs new file mode 100644 index 0000000..3f824c3 --- /dev/null +++ b/poulpy-backend/src/cpu_fft64_ref/tests.rs @@ -0,0 +1,9 @@ +use poulpy_hal::{api::ModuleNew, layouts::Module, test_suite::convolution::test_convolution}; + +use crate::FFT64Ref; + +#[test] +fn test_convolution_fft64_ref() { + let module: Module = Module::::new(64); + test_convolution(&module); +} diff --git a/poulpy-backend/src/cpu_fft64_ref/vec_znx_dft.rs b/poulpy-backend/src/cpu_fft64_ref/vec_znx_dft.rs index a2a743d..5ad6400 100644 --- a/poulpy-backend/src/cpu_fft64_ref/vec_znx_dft.rs +++ b/poulpy-backend/src/cpu_fft64_ref/vec_znx_dft.rs @@ -4,14 +4,15 @@ use poulpy_hal::{ VecZnxToRef, }, oep::{ - VecZnxDftAddImpl, VecZnxDftAddInplaceImpl, VecZnxDftAllocBytesImpl, VecZnxDftAllocImpl, VecZnxDftApplyImpl, - VecZnxDftCopyImpl, VecZnxDftFromBytesImpl, VecZnxDftSubImpl, VecZnxDftSubInplaceImpl, VecZnxDftSubNegateInplaceImpl, - VecZnxDftZeroImpl, VecZnxIdftApplyConsumeImpl, VecZnxIdftApplyImpl, VecZnxIdftApplyTmpAImpl, VecZnxIdftApplyTmpBytesImpl, + VecZnxDftAddImpl, VecZnxDftAddInplaceImpl, VecZnxDftAddScaledInplaceImpl, VecZnxDftAllocBytesImpl, VecZnxDftAllocImpl, + VecZnxDftApplyImpl, VecZnxDftCopyImpl, VecZnxDftFromBytesImpl, VecZnxDftSubImpl, VecZnxDftSubInplaceImpl, + VecZnxDftSubNegateInplaceImpl, VecZnxDftZeroImpl, VecZnxIdftApplyConsumeImpl, VecZnxIdftApplyImpl, + VecZnxIdftApplyTmpAImpl, VecZnxIdftApplyTmpBytesImpl, }, reference::fft64::vec_znx_dft::{ - vec_znx_dft_add, vec_znx_dft_add_inplace, vec_znx_dft_apply, vec_znx_dft_copy, vec_znx_dft_sub, vec_znx_dft_sub_inplace, - vec_znx_dft_sub_negate_inplace, vec_znx_dft_zero, vec_znx_idft_apply, vec_znx_idft_apply_consume, - vec_znx_idft_apply_tmpa, + vec_znx_dft_add, vec_znx_dft_add_inplace, vec_znx_dft_add_scaled_inplace, vec_znx_dft_apply, vec_znx_dft_copy, + vec_znx_dft_sub, vec_znx_dft_sub_inplace, vec_znx_dft_sub_negate_inplace, vec_znx_dft_zero, vec_znx_idft_apply, + vec_znx_idft_apply_consume, vec_znx_idft_apply_tmpa, }, }; @@ -111,6 +112,22 @@ unsafe impl VecZnxDftAddImpl for FFT64Ref { } } +unsafe impl VecZnxDftAddScaledInplaceImpl for FFT64Ref { + fn vec_znx_dft_add_scaled_inplace_impl( + _module: &Module, + res: &mut R, + res_col: usize, + a: &A, + a_col: usize, + a_scale: i64, + ) where + R: VecZnxDftToMut, + A: VecZnxDftToRef, + { + vec_znx_dft_add_scaled_inplace(res, res_col, a, a_col, a_scale); + } +} + unsafe impl VecZnxDftAddInplaceImpl for FFT64Ref { fn vec_znx_dft_add_inplace_impl(_module: &Module, res: &mut R, res_col: usize, a: &A, a_col: usize) where diff --git a/poulpy-hal/src/api/convolution.rs b/poulpy-hal/src/api/convolution.rs new file mode 100644 index 0000000..2f32de2 --- /dev/null +++ b/poulpy-hal/src/api/convolution.rs @@ -0,0 +1,109 @@ +use crate::{ + api::{ + ModuleN, ScratchTakeBasic, SvpApplyDftToDft, SvpPPolAlloc, SvpPPolBytesOf, SvpPrepare, VecZnxDftAddScaledInplace, + VecZnxDftBytesOf, + }, + layouts::{Backend, Module, Scratch, VecZnxDftToMut, VecZnxDftToRef, VecZnxToRef, ZnxInfos, ZnxZero}, +}; + +impl Convolution for Module +where + Self: Sized + + ModuleN + + SvpPPolAlloc + + SvpApplyDftToDft + + SvpPrepare + + SvpPPolBytesOf + + VecZnxDftBytesOf + + VecZnxDftAddScaledInplace, + Scratch: ScratchTakeBasic, +{ +} + +pub trait Convolution +where + Self: Sized + + ModuleN + + SvpPPolAlloc + + SvpApplyDftToDft + + SvpPrepare + + SvpPPolBytesOf + + VecZnxDftBytesOf + + VecZnxDftAddScaledInplace, + Scratch: ScratchTakeBasic, +{ + fn convolution_tmp_bytes(&self, res_size: usize) -> usize { + self.bytes_of_svp_ppol(1) + self.bytes_of_vec_znx_dft(1, res_size) + } + + /// Evaluates a bivariate convolution over Z[X, Y] / (X^N + 1) where Y = 2^-K + /// and scales the result by 2^{res_scale * K} + /// + /// # Example + /// a = [a00, a10, a20, a30] = (a00 * 2^-K + a01 * 2^-2K) + (a10 * 2^-K + a11 * 2^-2K) * X ... + /// [a01, a11, a21, a31] + /// + /// b = [b00, b10, b20, b30] = (b00 * 2^-K + b01 * 2^-2K) + (b10 * 2^-K + b11 * 2^-2K) * X ... + /// [b01, b11, b21, b31] + /// + /// If res_scale = 0: + /// res = [ 0, 0, 0, 0] = (r01 * 2^-2K + r02 * 2^-3K + r03 * 2^-4K + r04 * 2^-5K) + ... + /// [r01, r11, r21, r31] + /// [r02, r12, r22, r32] + /// [r03, r13, r23, r33] + /// [r04, r14, r24, r34] + /// + /// If res_scale = 1: + /// res = [r01, r11, r21, r31] = (r01 * 2^-K + r02 * 2^-2K + r03 * 2^-3K + r04 * 2^-4K + r05 * 2^-5K) + ... + /// [r02, r12, r22, r32] + /// [r03, r13, r23, r33] + /// [r04, r14, r24, r34] + /// [r05, r15, r25, r35] + /// + /// If res_scale = -1: + /// res = [ 0, 0, 0, 0] = (r01 * 2^-3K + r02 * 2^-4K + r03 * 2^-5K) + ... + /// [ 0, 0, 0, 0] + /// [r01, r11, r21, r31] + /// [r02, r12, r22, r32] + /// [r03, r13, r23, r33] + /// + /// If res.size() < a.size() + b.size() + 1 + res_scale, result is truncated accordingly in the Y dimension. + fn convolution(&self, res: &mut R, res_scale: i64, a: &A, b: &B, scratch: &mut Scratch) + where + R: VecZnxDftToMut, + A: VecZnxToRef, + B: VecZnxDftToRef, + { + let res: &mut crate::layouts::VecZnxDft<&mut [u8], BE> = &mut res.to_mut(); + let a: &crate::layouts::VecZnx<&[u8]> = &a.to_ref(); + let b: &crate::layouts::VecZnxDft<&[u8], BE> = &b.to_ref(); + + assert!(res.cols() >= a.cols() + b.cols() - 1); + + res.zero(); + + let (mut ppol, scratch_1) = scratch.take_svp_ppol(self, 1); + let (mut res_tmp, _) = scratch_1.take_vec_znx_dft(self, 1, res.size()); + + for a_col in 0..a.cols() { + for a_limb in 0..a.size() { + // Prepares the j-th limb of the i-th col of A + self.svp_prepare(&mut ppol, 0, &a.as_scalar_znx_ref(a_col, a_limb), 0); + + for b_col in 0..b.cols() { + // Multiplies with the i-th col of B + self.svp_apply_dft_to_dft(&mut res_tmp, 0, &ppol, 0, b, b_col); + + // Adds on the [a_col + b_col] of res, scaled by 2^{-(a_limb + 1) * Base2K} + self.vec_znx_dft_add_scaled_inplace( + res, + a_col + b_col, + &res_tmp, + 0, + -(1 + a_limb as i64) + res_scale, + ); + } + } + } + } +} diff --git a/poulpy-hal/src/api/mod.rs b/poulpy-hal/src/api/mod.rs index dac0def..b024a94 100644 --- a/poulpy-hal/src/api/mod.rs +++ b/poulpy-hal/src/api/mod.rs @@ -1,3 +1,4 @@ +mod convolution; mod module; mod scratch; mod svp_ppol; @@ -7,6 +8,7 @@ mod vec_znx_dft; mod vmp_pmat; mod zn; +pub use convolution::*; pub use module::*; pub use scratch::*; pub use svp_ppol::*; diff --git a/poulpy-hal/src/api/vec_znx_big.rs b/poulpy-hal/src/api/vec_znx_big.rs index 8cb5105..2cf9bba 100644 --- a/poulpy-hal/src/api/vec_znx_big.rs +++ b/poulpy-hal/src/api/vec_znx_big.rs @@ -164,10 +164,10 @@ pub trait VecZnxBigNormalizeTmpBytes { pub trait VecZnxBigNormalize { fn vec_znx_big_normalize( &self, - res_basek: usize, + res_base2k: usize, res: &mut R, res_col: usize, - a_basek: usize, + a_base2k: usize, a: &A, a_col: usize, scratch: &mut Scratch, diff --git a/poulpy-hal/src/api/vec_znx_dft.rs b/poulpy-hal/src/api/vec_znx_dft.rs index 3a003a9..61396c4 100644 --- a/poulpy-hal/src/api/vec_znx_dft.rs +++ b/poulpy-hal/src/api/vec_znx_dft.rs @@ -60,6 +60,13 @@ pub trait VecZnxDftAddInplace { A: VecZnxDftToRef; } +pub trait VecZnxDftAddScaledInplace { + fn vec_znx_dft_add_scaled_inplace(&self, res: &mut R, res_col: usize, a: &A, a_col: usize, a_scale: i64) + where + R: VecZnxDftToMut, + A: VecZnxDftToRef; +} + pub trait VecZnxDftSub { fn vec_znx_dft_sub(&self, res: &mut R, res_col: usize, a: &A, a_col: usize, b: &D, b_col: usize) where diff --git a/poulpy-hal/src/delegates/vec_znx_dft.rs b/poulpy-hal/src/delegates/vec_znx_dft.rs index 16a583f..7dfb25f 100644 --- a/poulpy-hal/src/delegates/vec_znx_dft.rs +++ b/poulpy-hal/src/delegates/vec_znx_dft.rs @@ -1,17 +1,18 @@ use crate::{ api::{ - VecZnxDftAdd, VecZnxDftAddInplace, VecZnxDftAlloc, VecZnxDftApply, VecZnxDftBytesOf, VecZnxDftCopy, VecZnxDftFromBytes, - VecZnxDftSub, VecZnxDftSubInplace, VecZnxDftSubNegateInplace, VecZnxDftZero, VecZnxIdftApply, VecZnxIdftApplyConsume, - VecZnxIdftApplyTmpA, VecZnxIdftApplyTmpBytes, + VecZnxDftAdd, VecZnxDftAddInplace, VecZnxDftAddScaledInplace, VecZnxDftAlloc, VecZnxDftApply, VecZnxDftBytesOf, + VecZnxDftCopy, VecZnxDftFromBytes, VecZnxDftSub, VecZnxDftSubInplace, VecZnxDftSubNegateInplace, VecZnxDftZero, + VecZnxIdftApply, VecZnxIdftApplyConsume, VecZnxIdftApplyTmpA, VecZnxIdftApplyTmpBytes, }, layouts::{ Backend, Data, Module, Scratch, VecZnxBig, VecZnxBigToMut, VecZnxDft, VecZnxDftOwned, VecZnxDftToMut, VecZnxDftToRef, VecZnxToRef, }, oep::{ - VecZnxDftAddImpl, VecZnxDftAddInplaceImpl, VecZnxDftAllocBytesImpl, VecZnxDftAllocImpl, VecZnxDftApplyImpl, - VecZnxDftCopyImpl, VecZnxDftFromBytesImpl, VecZnxDftSubImpl, VecZnxDftSubInplaceImpl, VecZnxDftSubNegateInplaceImpl, - VecZnxDftZeroImpl, VecZnxIdftApplyConsumeImpl, VecZnxIdftApplyImpl, VecZnxIdftApplyTmpAImpl, VecZnxIdftApplyTmpBytesImpl, + VecZnxDftAddImpl, VecZnxDftAddInplaceImpl, VecZnxDftAddScaledInplaceImpl, VecZnxDftAllocBytesImpl, VecZnxDftAllocImpl, + VecZnxDftApplyImpl, VecZnxDftCopyImpl, VecZnxDftFromBytesImpl, VecZnxDftSubImpl, VecZnxDftSubInplaceImpl, + VecZnxDftSubNegateInplaceImpl, VecZnxDftZeroImpl, VecZnxIdftApplyConsumeImpl, VecZnxIdftApplyImpl, + VecZnxIdftApplyTmpAImpl, VecZnxIdftApplyTmpBytesImpl, }, }; @@ -129,6 +130,19 @@ where } } +impl VecZnxDftAddScaledInplace for Module +where + B: Backend + VecZnxDftAddScaledInplaceImpl, +{ + fn vec_znx_dft_add_scaled_inplace(&self, res: &mut R, res_col: usize, a: &A, a_col: usize, a_scale: i64) + where + R: VecZnxDftToMut, + A: VecZnxDftToRef, + { + B::vec_znx_dft_add_scaled_inplace_impl(self, res, res_col, a, a_col, a_scale); + } +} + impl VecZnxDftSub for Module where B: Backend + VecZnxDftSubImpl, diff --git a/poulpy-hal/src/layouts/module.rs b/poulpy-hal/src/layouts/module.rs index 0556a6f..54e3ffa 100644 --- a/poulpy-hal/src/layouts/module.rs +++ b/poulpy-hal/src/layouts/module.rs @@ -4,6 +4,7 @@ use std::{ ptr::NonNull, }; +use bytemuck::Pod; use rand_distr::num_traits::Zero; use crate::{ @@ -13,8 +14,8 @@ use crate::{ #[allow(clippy::missing_safety_doc)] pub trait Backend: Sized { - type ScalarBig: Copy + Zero + Display + Debug; - type ScalarPrep: Copy + Zero + Display + Debug; + type ScalarBig: Copy + Zero + Display + Debug + Pod; + type ScalarPrep: Copy + Zero + Display + Debug + Pod; type Handle: 'static; fn layout_prep_word_count() -> usize; fn layout_big_word_count() -> usize; diff --git a/poulpy-hal/src/layouts/vec_znx.rs b/poulpy-hal/src/layouts/vec_znx.rs index c084934..1435243 100644 --- a/poulpy-hal/src/layouts/vec_znx.rs +++ b/poulpy-hal/src/layouts/vec_znx.rs @@ -6,8 +6,8 @@ use std::{ use crate::{ alloc_aligned, layouts::{ - Data, DataMut, DataRef, DataView, DataViewMut, DigestU64, FillUniform, ReaderFrom, ToOwnedDeep, WriterTo, ZnxInfos, - ZnxSliceSize, ZnxView, ZnxViewMut, ZnxZero, + Data, DataMut, DataRef, DataView, DataViewMut, DigestU64, FillUniform, ReaderFrom, ScalarZnx, ToOwnedDeep, WriterTo, + ZnxInfos, ZnxSliceSize, ZnxView, ZnxViewMut, ZnxZero, }, source::Source, }; @@ -25,6 +25,26 @@ pub struct VecZnx { pub max_size: usize, } +impl VecZnx { + pub fn as_scalar_znx_ref(&self, col: usize, limb: usize) -> ScalarZnx<&[u8]> { + ScalarZnx { + data: bytemuck::cast_slice(self.at(col, limb)), + n: self.n, + cols: 1, + } + } +} + +impl VecZnx { + pub fn as_scalar_znx_mut(&mut self, col: usize, limb: usize) -> ScalarZnx<&mut [u8]> { + ScalarZnx { + n: self.n, + cols: 1, + data: bytemuck::cast_slice_mut(self.at_mut(col, limb)), + } + } +} + impl Default for VecZnx { fn default() -> Self { Self { diff --git a/poulpy-hal/src/layouts/znx_base.rs b/poulpy-hal/src/layouts/znx_base.rs index 45e30d4..a2c5dd3 100644 --- a/poulpy-hal/src/layouts/znx_base.rs +++ b/poulpy-hal/src/layouts/znx_base.rs @@ -4,6 +4,7 @@ use crate::{ layouts::{Backend, Data, DataMut, DataRef}, source::Source, }; +use bytemuck::Pod; use rand_distr::num_traits::Zero; pub trait ZnxInfos { @@ -50,7 +51,7 @@ pub trait DataViewMut: DataView { } pub trait ZnxView: ZnxInfos + DataView { - type Scalar: Copy + Zero + Display + Debug; + type Scalar: Copy + Zero + Display + Debug + Pod; /// Returns a non-mutable pointer to the underlying coefficients array. fn as_ptr(&self) -> *const Self::Scalar { diff --git a/poulpy-hal/src/oep/vec_znx_dft.rs b/poulpy-hal/src/oep/vec_znx_dft.rs index 0f9288b..f561084 100644 --- a/poulpy-hal/src/oep/vec_znx_dft.rs +++ b/poulpy-hal/src/oep/vec_znx_dft.rs @@ -103,6 +103,23 @@ pub unsafe trait VecZnxDftAddImpl { D: VecZnxDftToRef; } +/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe) +/// * See TODO reference implementation. +/// * See [crate::api::VecZnxDftAddScaledInplace] for corresponding public API. +/// # Safety [crate::doc::backend_safety] for safety contract. +pub unsafe trait VecZnxDftAddScaledInplaceImpl { + fn vec_znx_dft_add_scaled_inplace_impl( + module: &Module, + res: &mut R, + res_col: usize, + a: &A, + a_col: usize, + a_scale: i64, + ) where + R: VecZnxDftToMut, + A: VecZnxDftToRef; +} + /// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe) /// * See the [poulpy-backend/src/cpu_fft64_ref/vec_znx_dft.rs](https://github.com/phantomzone-org/poulpy/blob/main/poulpy-backend/src/cpu_fft64_ref/vec_znx_dft.rs) reference implementation. /// * See [crate::api::VecZnxDftAddInplace] for corresponding public API. diff --git a/poulpy-hal/src/reference/fft64/vec_znx_dft.rs b/poulpy-hal/src/reference/fft64/vec_znx_dft.rs index 4bb086d..e8d12e6 100644 --- a/poulpy-hal/src/reference/fft64/vec_znx_dft.rs +++ b/poulpy-hal/src/reference/fft64/vec_znx_dft.rs @@ -92,6 +92,44 @@ where } } +/// res = res + a * 2^{a_scale * base2k}. +pub fn vec_znx_dft_add_scaled_inplace(res: &mut R, res_col: usize, a: &A, a_col: usize, a_scale: i64) +where + BE: Backend + ReimAddInplace, + R: VecZnxDftToMut, + A: VecZnxDftToRef, +{ + let a: VecZnxDft<&[u8], BE> = a.to_ref(); + let mut res: VecZnxDft<&mut [u8], BE> = res.to_mut(); + + #[cfg(debug_assertions)] + { + assert_eq!(a.n(), res.n()); + } + + let res_size: usize = res.size(); + let a_size: usize = a.size(); + + if a_scale > 0 { + let shift: usize = (a_scale as usize).min(a_size); + let sum_size: usize = a_size.min(res_size).saturating_sub(shift); + for j in 0..sum_size { + BE::reim_add_inplace(res.at_mut(res_col, j), a.at(a_col, j + shift)); + } + } else if a_scale < 0 { + let shift: usize = (a_scale.unsigned_abs() as usize).min(res_size); + let sum_size: usize = a_size.min(res_size).saturating_sub(shift); + for j in 0..sum_size { + BE::reim_add_inplace(res.at_mut(res_col, j + shift), a.at(a_col, j)); + } + } else { + let sum_size: usize = a_size.min(res_size); + for j in 0..sum_size { + BE::reim_add_inplace(res.at_mut(res_col, j), a.at(a_col, j)); + } + } +} + pub fn vec_znx_dft_copy(step: usize, offset: usize, res: &mut R, res_col: usize, a: &A, a_col: usize) where BE: Backend + ReimCopy + ReimZero, diff --git a/poulpy-hal/src/reference/vec_znx/convolution.rs b/poulpy-hal/src/reference/vec_znx/convolution.rs new file mode 100644 index 0000000..e69de29 diff --git a/poulpy-hal/src/test_suite/convolution.rs b/poulpy-hal/src/test_suite/convolution.rs new file mode 100644 index 0000000..8f4c71c --- /dev/null +++ b/poulpy-hal/src/test_suite/convolution.rs @@ -0,0 +1,162 @@ +use crate::{ + api::{ + Convolution, ModuleN, ScratchOwnedAlloc, ScratchOwnedBorrow, ScratchTakeBasic, TakeSlice, VecZnxBigNormalize, + VecZnxDftAlloc, VecZnxDftApply, VecZnxIdftApplyConsume, VecZnxNormalizeInplace, + }, + layouts::{ + Backend, FillUniform, Scratch, ScratchOwned, VecZnx, VecZnxBig, VecZnxDft, VecZnxToMut, VecZnxToRef, ZnxInfos, ZnxView, + ZnxViewMut, ZnxZero, + }, + source::Source, +}; + +pub fn test_convolution(module: &M) +where + M: ModuleN + + Convolution + + VecZnxDftAlloc + + VecZnxDftApply + + VecZnxIdftApplyConsume + + VecZnxBigNormalize + + VecZnxNormalizeInplace, + Scratch: ScratchTakeBasic, + ScratchOwned: ScratchOwnedAlloc + ScratchOwnedBorrow, +{ + let mut source: Source = Source::new([0u8; 32]); + + let base2k: usize = 12; + + for a_cols in 1..3 { + for b_cols in 1..3 { + for a_size in 1..5 { + for b_size in 1..5 { + let mut a: VecZnx> = VecZnx::alloc(module.n(), a_cols, a_size); + let mut b: VecZnx> = VecZnx::alloc(module.n(), b_cols, b_size); + + let mut c_want: VecZnx> = VecZnx::alloc(module.n(), a_cols + b_cols - 1, b_size + a_size); + let mut c_have: VecZnx> = VecZnx::alloc(module.n(), c_want.cols(), c_want.size()); + + let mut scratch: ScratchOwned = ScratchOwned::alloc(module.convolution_tmp_bytes(c_want.size())); + + a.fill_uniform(base2k, &mut source); + b.fill_uniform(base2k, &mut source); + + let mut b_dft: VecZnxDft, BE> = module.vec_znx_dft_alloc(b.cols(), b.size()); + + for i in 0..b.cols() { + module.vec_znx_dft_apply(1, 0, &mut b_dft, i, &b, i); + } + + for mut res_scale in 0..2 * c_want.size() as i64 + 1 { + res_scale -= c_want.size() as i64; + + let mut c_have_dft: VecZnxDft, BE> = module.vec_znx_dft_alloc(c_have.cols(), c_have.size()); + module.convolution(&mut c_have_dft, res_scale, &a, &b_dft, scratch.borrow()); + + let c_have_big: VecZnxBig, BE> = module.vec_znx_idft_apply_consume(c_have_dft); + + for i in 0..c_have.cols() { + module.vec_znx_big_normalize( + base2k, + &mut c_have, + i, + base2k, + &c_have_big, + i, + scratch.borrow(), + ); + } + + convolution_naive( + module, + base2k, + &mut c_want, + res_scale, + &a, + &b, + scratch.borrow(), + ); + + assert_eq!(c_want, c_have); + } + } + } + } + } +} + +fn convolution_naive( + module: &M, + base2k: usize, + res: &mut R, + res_scale: i64, + a: &A, + b: &B, + scratch: &mut Scratch, +) where + R: VecZnxToMut, + A: VecZnxToRef, + B: VecZnxToRef, + M: VecZnxNormalizeInplace, + Scratch: TakeSlice, +{ + let res: &mut VecZnx<&mut [u8]> = &mut res.to_mut(); + let a: &VecZnx<&[u8]> = &a.to_ref(); + let b: &VecZnx<&[u8]> = &b.to_ref(); + + assert!(res.cols() >= a.cols() + b.cols() - 1); + + res.zero(); + + for a_col in 0..a.cols() { + for a_limb in 0..a.size() { + for b_col in 0..b.cols() { + for b_limb in 0..b.size() { + let res_scale_abs = res_scale.unsigned_abs() as usize; + + let mut res_limb: usize = a_limb + b_limb + 1; + + if res_scale <= 0 { + res_limb += res_scale_abs; + + if res_limb < res.size() { + negacyclic_convolution_naive_add( + res.at_mut(a_col + b_col, res_limb), + a.at(a_col, a_limb), + b.at(b_col, b_limb), + ); + } + } else if res_limb >= res_scale_abs { + res_limb -= res_scale_abs; + + if res_limb < res.size() { + negacyclic_convolution_naive_add( + res.at_mut(a_col + b_col, res_limb), + a.at(a_col, a_limb), + b.at(b_col, b_limb), + ); + } + } + } + } + } + } + + for i in 0..res.cols() { + module.vec_znx_normalize_inplace(base2k, res, i, scratch); + } +} + +fn negacyclic_convolution_naive_add(res: &mut [i64], a: &[i64], b: &[i64]) { + let n: usize = res.len(); + for i in 0..n { + let ai: i64 = a[i]; + let lim: usize = n - i; + for j in 0..lim { + res[i + j] += ai * b[j]; + } + for j in lim..n { + res[i + j - n] -= ai * b[j]; + } + } +} diff --git a/poulpy-hal/src/test_suite/mod.rs b/poulpy-hal/src/test_suite/mod.rs index f31c856..70943c7 100644 --- a/poulpy-hal/src/test_suite/mod.rs +++ b/poulpy-hal/src/test_suite/mod.rs @@ -1,3 +1,4 @@ +pub mod convolution; pub mod serialization; pub mod svp; pub mod vec_znx;