mirror of
https://github.com/arnaucube/poulpy.git
synced 2026-02-10 05:06:44 +01:00
Add bivariate convolution
This commit is contained in:
@@ -1,4 +1,8 @@
|
|||||||
use poulpy_hal::{backend_test_suite, cross_backend_test_suite};
|
use poulpy_hal::{
|
||||||
|
api::ModuleNew, backend_test_suite, cross_backend_test_suite, layouts::Module, test_suite::convolution::test_convolution,
|
||||||
|
};
|
||||||
|
|
||||||
|
use crate::FFT64Avx;
|
||||||
|
|
||||||
cross_backend_test_suite! {
|
cross_backend_test_suite! {
|
||||||
mod vec_znx,
|
mod vec_znx,
|
||||||
@@ -115,3 +119,9 @@ backend_test_suite! {
|
|||||||
test_vec_znx_big_sub_small_b_inplace => poulpy_hal::reference::fft64::vec_znx_big::test_vec_znx_big_add_normal,
|
test_vec_znx_big_sub_small_b_inplace => poulpy_hal::reference::fft64::vec_znx_big::test_vec_znx_big_add_normal,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Runs the shared convolution test suite against the `FFT64Avx` backend.
#[test]
fn test_convolution_fft64_avx() {
    let module = Module::<FFT64Avx>::new(64);
    test_convolution(&module);
}
|
||||||
|
|||||||
@@ -4,14 +4,15 @@ use poulpy_hal::{
|
|||||||
VecZnxToRef,
|
VecZnxToRef,
|
||||||
},
|
},
|
||||||
oep::{
|
oep::{
|
||||||
VecZnxDftAddImpl, VecZnxDftAddInplaceImpl, VecZnxDftAllocBytesImpl, VecZnxDftAllocImpl, VecZnxDftApplyImpl,
|
VecZnxDftAddImpl, VecZnxDftAddInplaceImpl, VecZnxDftAddScaledInplaceImpl, VecZnxDftAllocBytesImpl, VecZnxDftAllocImpl,
|
||||||
VecZnxDftCopyImpl, VecZnxDftFromBytesImpl, VecZnxDftSubImpl, VecZnxDftSubInplaceImpl, VecZnxDftSubNegateInplaceImpl,
|
VecZnxDftApplyImpl, VecZnxDftCopyImpl, VecZnxDftFromBytesImpl, VecZnxDftSubImpl, VecZnxDftSubInplaceImpl,
|
||||||
VecZnxDftZeroImpl, VecZnxIdftApplyConsumeImpl, VecZnxIdftApplyImpl, VecZnxIdftApplyTmpAImpl, VecZnxIdftApplyTmpBytesImpl,
|
VecZnxDftSubNegateInplaceImpl, VecZnxDftZeroImpl, VecZnxIdftApplyConsumeImpl, VecZnxIdftApplyImpl,
|
||||||
|
VecZnxIdftApplyTmpAImpl, VecZnxIdftApplyTmpBytesImpl,
|
||||||
},
|
},
|
||||||
reference::fft64::vec_znx_dft::{
|
reference::fft64::vec_znx_dft::{
|
||||||
vec_znx_dft_add, vec_znx_dft_add_inplace, vec_znx_dft_apply, vec_znx_dft_copy, vec_znx_dft_sub, vec_znx_dft_sub_inplace,
|
vec_znx_dft_add, vec_znx_dft_add_inplace, vec_znx_dft_add_scaled_inplace, vec_znx_dft_apply, vec_znx_dft_copy,
|
||||||
vec_znx_dft_sub_negate_inplace, vec_znx_dft_zero, vec_znx_idft_apply, vec_znx_idft_apply_consume,
|
vec_znx_dft_sub, vec_znx_dft_sub_inplace, vec_znx_dft_sub_negate_inplace, vec_znx_dft_zero, vec_znx_idft_apply,
|
||||||
vec_znx_idft_apply_tmpa,
|
vec_znx_idft_apply_consume, vec_znx_idft_apply_tmpa,
|
||||||
},
|
},
|
||||||
};
|
};
|
||||||
|
|
||||||
@@ -121,6 +122,22 @@ unsafe impl VecZnxDftAddInplaceImpl<Self> for FFT64Avx {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
unsafe impl VecZnxDftAddScaledInplaceImpl<Self> for FFT64Avx {
|
||||||
|
fn vec_znx_dft_add_scaled_inplace_impl<R, A>(
|
||||||
|
_module: &Module<Self>,
|
||||||
|
res: &mut R,
|
||||||
|
res_col: usize,
|
||||||
|
a: &A,
|
||||||
|
a_col: usize,
|
||||||
|
a_scale: i64,
|
||||||
|
) where
|
||||||
|
R: VecZnxDftToMut<Self>,
|
||||||
|
A: VecZnxDftToRef<Self>,
|
||||||
|
{
|
||||||
|
vec_znx_dft_add_scaled_inplace(res, res_col, a, a_col, a_scale);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
unsafe impl VecZnxDftSubImpl<Self> for FFT64Avx {
|
unsafe impl VecZnxDftSubImpl<Self> for FFT64Avx {
|
||||||
fn vec_znx_dft_sub_impl<R, A, B>(
|
fn vec_znx_dft_sub_impl<R, A, B>(
|
||||||
_module: &Module<Self>,
|
_module: &Module<Self>,
|
||||||
|
|||||||
@@ -9,4 +9,7 @@ mod vmp;
|
|||||||
mod zn;
|
mod zn;
|
||||||
mod znx;
|
mod znx;
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests;
|
||||||
|
|
||||||
pub struct FFT64Ref {}
|
pub struct FFT64Ref {}
|
||||||
|
|||||||
9
poulpy-backend/src/cpu_fft64_ref/tests.rs
Normal file
9
poulpy-backend/src/cpu_fft64_ref/tests.rs
Normal file
@@ -0,0 +1,9 @@
|
|||||||
|
use poulpy_hal::{api::ModuleNew, layouts::Module, test_suite::convolution::test_convolution};
|
||||||
|
|
||||||
|
use crate::FFT64Ref;
|
||||||
|
|
||||||
|
/// Runs the shared convolution test suite against the `FFT64Ref` backend.
#[test]
fn test_convolution_fft64_ref() {
    let module = Module::<FFT64Ref>::new(64);
    test_convolution(&module);
}
|
||||||
@@ -4,14 +4,15 @@ use poulpy_hal::{
|
|||||||
VecZnxToRef,
|
VecZnxToRef,
|
||||||
},
|
},
|
||||||
oep::{
|
oep::{
|
||||||
VecZnxDftAddImpl, VecZnxDftAddInplaceImpl, VecZnxDftAllocBytesImpl, VecZnxDftAllocImpl, VecZnxDftApplyImpl,
|
VecZnxDftAddImpl, VecZnxDftAddInplaceImpl, VecZnxDftAddScaledInplaceImpl, VecZnxDftAllocBytesImpl, VecZnxDftAllocImpl,
|
||||||
VecZnxDftCopyImpl, VecZnxDftFromBytesImpl, VecZnxDftSubImpl, VecZnxDftSubInplaceImpl, VecZnxDftSubNegateInplaceImpl,
|
VecZnxDftApplyImpl, VecZnxDftCopyImpl, VecZnxDftFromBytesImpl, VecZnxDftSubImpl, VecZnxDftSubInplaceImpl,
|
||||||
VecZnxDftZeroImpl, VecZnxIdftApplyConsumeImpl, VecZnxIdftApplyImpl, VecZnxIdftApplyTmpAImpl, VecZnxIdftApplyTmpBytesImpl,
|
VecZnxDftSubNegateInplaceImpl, VecZnxDftZeroImpl, VecZnxIdftApplyConsumeImpl, VecZnxIdftApplyImpl,
|
||||||
|
VecZnxIdftApplyTmpAImpl, VecZnxIdftApplyTmpBytesImpl,
|
||||||
},
|
},
|
||||||
reference::fft64::vec_znx_dft::{
|
reference::fft64::vec_znx_dft::{
|
||||||
vec_znx_dft_add, vec_znx_dft_add_inplace, vec_znx_dft_apply, vec_znx_dft_copy, vec_znx_dft_sub, vec_znx_dft_sub_inplace,
|
vec_znx_dft_add, vec_znx_dft_add_inplace, vec_znx_dft_add_scaled_inplace, vec_znx_dft_apply, vec_znx_dft_copy,
|
||||||
vec_znx_dft_sub_negate_inplace, vec_znx_dft_zero, vec_znx_idft_apply, vec_znx_idft_apply_consume,
|
vec_znx_dft_sub, vec_znx_dft_sub_inplace, vec_znx_dft_sub_negate_inplace, vec_znx_dft_zero, vec_znx_idft_apply,
|
||||||
vec_znx_idft_apply_tmpa,
|
vec_znx_idft_apply_consume, vec_znx_idft_apply_tmpa,
|
||||||
},
|
},
|
||||||
};
|
};
|
||||||
|
|
||||||
@@ -111,6 +112,22 @@ unsafe impl VecZnxDftAddImpl<Self> for FFT64Ref {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
unsafe impl VecZnxDftAddScaledInplaceImpl<Self> for FFT64Ref {
|
||||||
|
fn vec_znx_dft_add_scaled_inplace_impl<R, A>(
|
||||||
|
_module: &Module<Self>,
|
||||||
|
res: &mut R,
|
||||||
|
res_col: usize,
|
||||||
|
a: &A,
|
||||||
|
a_col: usize,
|
||||||
|
a_scale: i64,
|
||||||
|
) where
|
||||||
|
R: VecZnxDftToMut<Self>,
|
||||||
|
A: VecZnxDftToRef<Self>,
|
||||||
|
{
|
||||||
|
vec_znx_dft_add_scaled_inplace(res, res_col, a, a_col, a_scale);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
unsafe impl VecZnxDftAddInplaceImpl<Self> for FFT64Ref {
|
unsafe impl VecZnxDftAddInplaceImpl<Self> for FFT64Ref {
|
||||||
fn vec_znx_dft_add_inplace_impl<R, A>(_module: &Module<Self>, res: &mut R, res_col: usize, a: &A, a_col: usize)
|
fn vec_znx_dft_add_inplace_impl<R, A>(_module: &Module<Self>, res: &mut R, res_col: usize, a: &A, a_col: usize)
|
||||||
where
|
where
|
||||||
|
|||||||
109
poulpy-hal/src/api/convolution.rs
Normal file
109
poulpy-hal/src/api/convolution.rs
Normal file
@@ -0,0 +1,109 @@
|
|||||||
|
use crate::{
|
||||||
|
api::{
|
||||||
|
ModuleN, ScratchTakeBasic, SvpApplyDftToDft, SvpPPolAlloc, SvpPPolBytesOf, SvpPrepare, VecZnxDftAddScaledInplace,
|
||||||
|
VecZnxDftBytesOf,
|
||||||
|
},
|
||||||
|
layouts::{Backend, Module, Scratch, VecZnxDftToMut, VecZnxDftToRef, VecZnxToRef, ZnxInfos, ZnxZero},
|
||||||
|
};
|
||||||
|
|
||||||
|
impl<BE: Backend> Convolution<BE> for Module<BE>
|
||||||
|
where
|
||||||
|
Self: Sized
|
||||||
|
+ ModuleN
|
||||||
|
+ SvpPPolAlloc<BE>
|
||||||
|
+ SvpApplyDftToDft<BE>
|
||||||
|
+ SvpPrepare<BE>
|
||||||
|
+ SvpPPolBytesOf
|
||||||
|
+ VecZnxDftBytesOf
|
||||||
|
+ VecZnxDftAddScaledInplace<BE>,
|
||||||
|
Scratch<BE>: ScratchTakeBasic,
|
||||||
|
{
|
||||||
|
}
|
||||||
|
|
||||||
|
pub trait Convolution<BE: Backend>
|
||||||
|
where
|
||||||
|
Self: Sized
|
||||||
|
+ ModuleN
|
||||||
|
+ SvpPPolAlloc<BE>
|
||||||
|
+ SvpApplyDftToDft<BE>
|
||||||
|
+ SvpPrepare<BE>
|
||||||
|
+ SvpPPolBytesOf
|
||||||
|
+ VecZnxDftBytesOf
|
||||||
|
+ VecZnxDftAddScaledInplace<BE>,
|
||||||
|
Scratch<BE>: ScratchTakeBasic,
|
||||||
|
{
|
||||||
|
fn convolution_tmp_bytes(&self, res_size: usize) -> usize {
|
||||||
|
self.bytes_of_svp_ppol(1) + self.bytes_of_vec_znx_dft(1, res_size)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Evaluates a bivariate convolution over Z[X, Y] / (X^N + 1) where Y = 2^-K
|
||||||
|
/// and scales the result by 2^{res_scale * K}
|
||||||
|
///
|
||||||
|
/// # Example
|
||||||
|
/// a = [a00, a10, a20, a30] = (a00 * 2^-K + a01 * 2^-2K) + (a10 * 2^-K + a11 * 2^-2K) * X ...
|
||||||
|
/// [a01, a11, a21, a31]
|
||||||
|
///
|
||||||
|
/// b = [b00, b10, b20, b30] = (b00 * 2^-K + b01 * 2^-2K) + (b10 * 2^-K + b11 * 2^-2K) * X ...
|
||||||
|
/// [b01, b11, b21, b31]
|
||||||
|
///
|
||||||
|
/// If res_scale = 0:
|
||||||
|
/// res = [ 0, 0, 0, 0] = (r01 * 2^-2K + r02 * 2^-3K + r03 * 2^-4K + r04 * 2^-5K) + ...
|
||||||
|
/// [r01, r11, r21, r31]
|
||||||
|
/// [r02, r12, r22, r32]
|
||||||
|
/// [r03, r13, r23, r33]
|
||||||
|
/// [r04, r14, r24, r34]
|
||||||
|
///
|
||||||
|
/// If res_scale = 1:
|
||||||
|
/// res = [r01, r11, r21, r31] = (r01 * 2^-K + r02 * 2^-2K + r03 * 2^-3K + r04 * 2^-4K + r05 * 2^-5K) + ...
|
||||||
|
/// [r02, r12, r22, r32]
|
||||||
|
/// [r03, r13, r23, r33]
|
||||||
|
/// [r04, r14, r24, r34]
|
||||||
|
/// [r05, r15, r25, r35]
|
||||||
|
///
|
||||||
|
/// If res_scale = -1:
|
||||||
|
/// res = [ 0, 0, 0, 0] = (r01 * 2^-3K + r02 * 2^-4K + r03 * 2^-5K) + ...
|
||||||
|
/// [ 0, 0, 0, 0]
|
||||||
|
/// [r01, r11, r21, r31]
|
||||||
|
/// [r02, r12, r22, r32]
|
||||||
|
/// [r03, r13, r23, r33]
|
||||||
|
///
|
||||||
|
/// If res.size() < a.size() + b.size() + 1 + res_scale, result is truncated accordingly in the Y dimension.
|
||||||
|
fn convolution<R, A, B>(&self, res: &mut R, res_scale: i64, a: &A, b: &B, scratch: &mut Scratch<BE>)
|
||||||
|
where
|
||||||
|
R: VecZnxDftToMut<BE>,
|
||||||
|
A: VecZnxToRef,
|
||||||
|
B: VecZnxDftToRef<BE>,
|
||||||
|
{
|
||||||
|
let res: &mut crate::layouts::VecZnxDft<&mut [u8], BE> = &mut res.to_mut();
|
||||||
|
let a: &crate::layouts::VecZnx<&[u8]> = &a.to_ref();
|
||||||
|
let b: &crate::layouts::VecZnxDft<&[u8], BE> = &b.to_ref();
|
||||||
|
|
||||||
|
assert!(res.cols() >= a.cols() + b.cols() - 1);
|
||||||
|
|
||||||
|
res.zero();
|
||||||
|
|
||||||
|
let (mut ppol, scratch_1) = scratch.take_svp_ppol(self, 1);
|
||||||
|
let (mut res_tmp, _) = scratch_1.take_vec_znx_dft(self, 1, res.size());
|
||||||
|
|
||||||
|
for a_col in 0..a.cols() {
|
||||||
|
for a_limb in 0..a.size() {
|
||||||
|
// Prepares the j-th limb of the i-th col of A
|
||||||
|
self.svp_prepare(&mut ppol, 0, &a.as_scalar_znx_ref(a_col, a_limb), 0);
|
||||||
|
|
||||||
|
for b_col in 0..b.cols() {
|
||||||
|
// Multiplies with the i-th col of B
|
||||||
|
self.svp_apply_dft_to_dft(&mut res_tmp, 0, &ppol, 0, b, b_col);
|
||||||
|
|
||||||
|
// Adds on the [a_col + b_col] of res, scaled by 2^{-(a_limb + 1) * Base2K}
|
||||||
|
self.vec_znx_dft_add_scaled_inplace(
|
||||||
|
res,
|
||||||
|
a_col + b_col,
|
||||||
|
&res_tmp,
|
||||||
|
0,
|
||||||
|
-(1 + a_limb as i64) + res_scale,
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -1,3 +1,4 @@
|
|||||||
|
mod convolution;
|
||||||
mod module;
|
mod module;
|
||||||
mod scratch;
|
mod scratch;
|
||||||
mod svp_ppol;
|
mod svp_ppol;
|
||||||
@@ -7,6 +8,7 @@ mod vec_znx_dft;
|
|||||||
mod vmp_pmat;
|
mod vmp_pmat;
|
||||||
mod zn;
|
mod zn;
|
||||||
|
|
||||||
|
pub use convolution::*;
|
||||||
pub use module::*;
|
pub use module::*;
|
||||||
pub use scratch::*;
|
pub use scratch::*;
|
||||||
pub use svp_ppol::*;
|
pub use svp_ppol::*;
|
||||||
|
|||||||
@@ -164,10 +164,10 @@ pub trait VecZnxBigNormalizeTmpBytes {
|
|||||||
pub trait VecZnxBigNormalize<B: Backend> {
|
pub trait VecZnxBigNormalize<B: Backend> {
|
||||||
fn vec_znx_big_normalize<R, A>(
|
fn vec_znx_big_normalize<R, A>(
|
||||||
&self,
|
&self,
|
||||||
res_basek: usize,
|
res_base2k: usize,
|
||||||
res: &mut R,
|
res: &mut R,
|
||||||
res_col: usize,
|
res_col: usize,
|
||||||
a_basek: usize,
|
a_base2k: usize,
|
||||||
a: &A,
|
a: &A,
|
||||||
a_col: usize,
|
a_col: usize,
|
||||||
scratch: &mut Scratch<B>,
|
scratch: &mut Scratch<B>,
|
||||||
|
|||||||
@@ -60,6 +60,13 @@ pub trait VecZnxDftAddInplace<B: Backend> {
|
|||||||
A: VecZnxDftToRef<B>;
|
A: VecZnxDftToRef<B>;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub trait VecZnxDftAddScaledInplace<B: Backend> {
|
||||||
|
fn vec_znx_dft_add_scaled_inplace<R, A>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize, a_scale: i64)
|
||||||
|
where
|
||||||
|
R: VecZnxDftToMut<B>,
|
||||||
|
A: VecZnxDftToRef<B>;
|
||||||
|
}
|
||||||
|
|
||||||
pub trait VecZnxDftSub<B: Backend> {
|
pub trait VecZnxDftSub<B: Backend> {
|
||||||
fn vec_znx_dft_sub<R, A, D>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize, b: &D, b_col: usize)
|
fn vec_znx_dft_sub<R, A, D>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize, b: &D, b_col: usize)
|
||||||
where
|
where
|
||||||
|
|||||||
@@ -1,17 +1,18 @@
|
|||||||
use crate::{
|
use crate::{
|
||||||
api::{
|
api::{
|
||||||
VecZnxDftAdd, VecZnxDftAddInplace, VecZnxDftAlloc, VecZnxDftApply, VecZnxDftBytesOf, VecZnxDftCopy, VecZnxDftFromBytes,
|
VecZnxDftAdd, VecZnxDftAddInplace, VecZnxDftAddScaledInplace, VecZnxDftAlloc, VecZnxDftApply, VecZnxDftBytesOf,
|
||||||
VecZnxDftSub, VecZnxDftSubInplace, VecZnxDftSubNegateInplace, VecZnxDftZero, VecZnxIdftApply, VecZnxIdftApplyConsume,
|
VecZnxDftCopy, VecZnxDftFromBytes, VecZnxDftSub, VecZnxDftSubInplace, VecZnxDftSubNegateInplace, VecZnxDftZero,
|
||||||
VecZnxIdftApplyTmpA, VecZnxIdftApplyTmpBytes,
|
VecZnxIdftApply, VecZnxIdftApplyConsume, VecZnxIdftApplyTmpA, VecZnxIdftApplyTmpBytes,
|
||||||
},
|
},
|
||||||
layouts::{
|
layouts::{
|
||||||
Backend, Data, Module, Scratch, VecZnxBig, VecZnxBigToMut, VecZnxDft, VecZnxDftOwned, VecZnxDftToMut, VecZnxDftToRef,
|
Backend, Data, Module, Scratch, VecZnxBig, VecZnxBigToMut, VecZnxDft, VecZnxDftOwned, VecZnxDftToMut, VecZnxDftToRef,
|
||||||
VecZnxToRef,
|
VecZnxToRef,
|
||||||
},
|
},
|
||||||
oep::{
|
oep::{
|
||||||
VecZnxDftAddImpl, VecZnxDftAddInplaceImpl, VecZnxDftAllocBytesImpl, VecZnxDftAllocImpl, VecZnxDftApplyImpl,
|
VecZnxDftAddImpl, VecZnxDftAddInplaceImpl, VecZnxDftAddScaledInplaceImpl, VecZnxDftAllocBytesImpl, VecZnxDftAllocImpl,
|
||||||
VecZnxDftCopyImpl, VecZnxDftFromBytesImpl, VecZnxDftSubImpl, VecZnxDftSubInplaceImpl, VecZnxDftSubNegateInplaceImpl,
|
VecZnxDftApplyImpl, VecZnxDftCopyImpl, VecZnxDftFromBytesImpl, VecZnxDftSubImpl, VecZnxDftSubInplaceImpl,
|
||||||
VecZnxDftZeroImpl, VecZnxIdftApplyConsumeImpl, VecZnxIdftApplyImpl, VecZnxIdftApplyTmpAImpl, VecZnxIdftApplyTmpBytesImpl,
|
VecZnxDftSubNegateInplaceImpl, VecZnxDftZeroImpl, VecZnxIdftApplyConsumeImpl, VecZnxIdftApplyImpl,
|
||||||
|
VecZnxIdftApplyTmpAImpl, VecZnxIdftApplyTmpBytesImpl,
|
||||||
},
|
},
|
||||||
};
|
};
|
||||||
|
|
||||||
@@ -129,6 +130,19 @@ where
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
impl<B> VecZnxDftAddScaledInplace<B> for Module<B>
|
||||||
|
where
|
||||||
|
B: Backend + VecZnxDftAddScaledInplaceImpl<B>,
|
||||||
|
{
|
||||||
|
fn vec_znx_dft_add_scaled_inplace<R, A>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize, a_scale: i64)
|
||||||
|
where
|
||||||
|
R: VecZnxDftToMut<B>,
|
||||||
|
A: VecZnxDftToRef<B>,
|
||||||
|
{
|
||||||
|
B::vec_znx_dft_add_scaled_inplace_impl(self, res, res_col, a, a_col, a_scale);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
impl<B> VecZnxDftSub<B> for Module<B>
|
impl<B> VecZnxDftSub<B> for Module<B>
|
||||||
where
|
where
|
||||||
B: Backend + VecZnxDftSubImpl<B>,
|
B: Backend + VecZnxDftSubImpl<B>,
|
||||||
|
|||||||
@@ -4,6 +4,7 @@ use std::{
|
|||||||
ptr::NonNull,
|
ptr::NonNull,
|
||||||
};
|
};
|
||||||
|
|
||||||
|
use bytemuck::Pod;
|
||||||
use rand_distr::num_traits::Zero;
|
use rand_distr::num_traits::Zero;
|
||||||
|
|
||||||
use crate::{
|
use crate::{
|
||||||
@@ -13,8 +14,8 @@ use crate::{
|
|||||||
|
|
||||||
#[allow(clippy::missing_safety_doc)]
|
#[allow(clippy::missing_safety_doc)]
|
||||||
pub trait Backend: Sized {
|
pub trait Backend: Sized {
|
||||||
type ScalarBig: Copy + Zero + Display + Debug;
|
type ScalarBig: Copy + Zero + Display + Debug + Pod;
|
||||||
type ScalarPrep: Copy + Zero + Display + Debug;
|
type ScalarPrep: Copy + Zero + Display + Debug + Pod;
|
||||||
type Handle: 'static;
|
type Handle: 'static;
|
||||||
fn layout_prep_word_count() -> usize;
|
fn layout_prep_word_count() -> usize;
|
||||||
fn layout_big_word_count() -> usize;
|
fn layout_big_word_count() -> usize;
|
||||||
|
|||||||
@@ -6,8 +6,8 @@ use std::{
|
|||||||
use crate::{
|
use crate::{
|
||||||
alloc_aligned,
|
alloc_aligned,
|
||||||
layouts::{
|
layouts::{
|
||||||
Data, DataMut, DataRef, DataView, DataViewMut, DigestU64, FillUniform, ReaderFrom, ToOwnedDeep, WriterTo, ZnxInfos,
|
Data, DataMut, DataRef, DataView, DataViewMut, DigestU64, FillUniform, ReaderFrom, ScalarZnx, ToOwnedDeep, WriterTo,
|
||||||
ZnxSliceSize, ZnxView, ZnxViewMut, ZnxZero,
|
ZnxInfos, ZnxSliceSize, ZnxView, ZnxViewMut, ZnxZero,
|
||||||
},
|
},
|
||||||
source::Source,
|
source::Source,
|
||||||
};
|
};
|
||||||
@@ -25,6 +25,26 @@ pub struct VecZnx<D: Data> {
|
|||||||
pub max_size: usize,
|
pub max_size: usize,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
impl<D: DataRef> VecZnx<D> {
|
||||||
|
pub fn as_scalar_znx_ref(&self, col: usize, limb: usize) -> ScalarZnx<&[u8]> {
|
||||||
|
ScalarZnx {
|
||||||
|
data: bytemuck::cast_slice(self.at(col, limb)),
|
||||||
|
n: self.n,
|
||||||
|
cols: 1,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<D: DataMut> VecZnx<D> {
|
||||||
|
pub fn as_scalar_znx_mut(&mut self, col: usize, limb: usize) -> ScalarZnx<&mut [u8]> {
|
||||||
|
ScalarZnx {
|
||||||
|
n: self.n,
|
||||||
|
cols: 1,
|
||||||
|
data: bytemuck::cast_slice_mut(self.at_mut(col, limb)),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
impl<D: Data + Default> Default for VecZnx<D> {
|
impl<D: Data + Default> Default for VecZnx<D> {
|
||||||
fn default() -> Self {
|
fn default() -> Self {
|
||||||
Self {
|
Self {
|
||||||
|
|||||||
@@ -4,6 +4,7 @@ use crate::{
|
|||||||
layouts::{Backend, Data, DataMut, DataRef},
|
layouts::{Backend, Data, DataMut, DataRef},
|
||||||
source::Source,
|
source::Source,
|
||||||
};
|
};
|
||||||
|
use bytemuck::Pod;
|
||||||
use rand_distr::num_traits::Zero;
|
use rand_distr::num_traits::Zero;
|
||||||
|
|
||||||
pub trait ZnxInfos {
|
pub trait ZnxInfos {
|
||||||
@@ -50,7 +51,7 @@ pub trait DataViewMut: DataView {
|
|||||||
}
|
}
|
||||||
|
|
||||||
pub trait ZnxView: ZnxInfos + DataView<D: DataRef> {
|
pub trait ZnxView: ZnxInfos + DataView<D: DataRef> {
|
||||||
type Scalar: Copy + Zero + Display + Debug;
|
type Scalar: Copy + Zero + Display + Debug + Pod;
|
||||||
|
|
||||||
/// Returns a non-mutable pointer to the underlying coefficients array.
|
/// Returns a non-mutable pointer to the underlying coefficients array.
|
||||||
fn as_ptr(&self) -> *const Self::Scalar {
|
fn as_ptr(&self) -> *const Self::Scalar {
|
||||||
|
|||||||
@@ -103,6 +103,23 @@ pub unsafe trait VecZnxDftAddImpl<B: Backend> {
|
|||||||
D: VecZnxDftToRef<B>;
|
D: VecZnxDftToRef<B>;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe)
|
||||||
|
/// * See TODO reference implementation.
|
||||||
|
/// * See [crate::api::VecZnxDftAddScaledInplace] for corresponding public API.
|
||||||
|
/// # Safety [crate::doc::backend_safety] for safety contract.
|
||||||
|
pub unsafe trait VecZnxDftAddScaledInplaceImpl<B: Backend> {
|
||||||
|
fn vec_znx_dft_add_scaled_inplace_impl<R, A>(
|
||||||
|
module: &Module<B>,
|
||||||
|
res: &mut R,
|
||||||
|
res_col: usize,
|
||||||
|
a: &A,
|
||||||
|
a_col: usize,
|
||||||
|
a_scale: i64,
|
||||||
|
) where
|
||||||
|
R: VecZnxDftToMut<B>,
|
||||||
|
A: VecZnxDftToRef<B>;
|
||||||
|
}
|
||||||
|
|
||||||
/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe)
|
/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe)
|
||||||
/// * See the [poulpy-backend/src/cpu_fft64_ref/vec_znx_dft.rs](https://github.com/phantomzone-org/poulpy/blob/main/poulpy-backend/src/cpu_fft64_ref/vec_znx_dft.rs) reference implementation.
|
/// * See the [poulpy-backend/src/cpu_fft64_ref/vec_znx_dft.rs](https://github.com/phantomzone-org/poulpy/blob/main/poulpy-backend/src/cpu_fft64_ref/vec_znx_dft.rs) reference implementation.
|
||||||
/// * See [crate::api::VecZnxDftAddInplace] for corresponding public API.
|
/// * See [crate::api::VecZnxDftAddInplace] for corresponding public API.
|
||||||
|
|||||||
@@ -92,6 +92,44 @@ where
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// res = res + a * 2^{a_scale * base2k}.
|
||||||
|
pub fn vec_znx_dft_add_scaled_inplace<R, A, BE>(res: &mut R, res_col: usize, a: &A, a_col: usize, a_scale: i64)
|
||||||
|
where
|
||||||
|
BE: Backend<ScalarPrep = f64> + ReimAddInplace,
|
||||||
|
R: VecZnxDftToMut<BE>,
|
||||||
|
A: VecZnxDftToRef<BE>,
|
||||||
|
{
|
||||||
|
let a: VecZnxDft<&[u8], BE> = a.to_ref();
|
||||||
|
let mut res: VecZnxDft<&mut [u8], BE> = res.to_mut();
|
||||||
|
|
||||||
|
#[cfg(debug_assertions)]
|
||||||
|
{
|
||||||
|
assert_eq!(a.n(), res.n());
|
||||||
|
}
|
||||||
|
|
||||||
|
let res_size: usize = res.size();
|
||||||
|
let a_size: usize = a.size();
|
||||||
|
|
||||||
|
if a_scale > 0 {
|
||||||
|
let shift: usize = (a_scale as usize).min(a_size);
|
||||||
|
let sum_size: usize = a_size.min(res_size).saturating_sub(shift);
|
||||||
|
for j in 0..sum_size {
|
||||||
|
BE::reim_add_inplace(res.at_mut(res_col, j), a.at(a_col, j + shift));
|
||||||
|
}
|
||||||
|
} else if a_scale < 0 {
|
||||||
|
let shift: usize = (a_scale.unsigned_abs() as usize).min(res_size);
|
||||||
|
let sum_size: usize = a_size.min(res_size).saturating_sub(shift);
|
||||||
|
for j in 0..sum_size {
|
||||||
|
BE::reim_add_inplace(res.at_mut(res_col, j + shift), a.at(a_col, j));
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
let sum_size: usize = a_size.min(res_size);
|
||||||
|
for j in 0..sum_size {
|
||||||
|
BE::reim_add_inplace(res.at_mut(res_col, j), a.at(a_col, j));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
pub fn vec_znx_dft_copy<R, A, BE>(step: usize, offset: usize, res: &mut R, res_col: usize, a: &A, a_col: usize)
|
pub fn vec_znx_dft_copy<R, A, BE>(step: usize, offset: usize, res: &mut R, res_col: usize, a: &A, a_col: usize)
|
||||||
where
|
where
|
||||||
BE: Backend<ScalarPrep = f64> + ReimCopy + ReimZero,
|
BE: Backend<ScalarPrep = f64> + ReimCopy + ReimZero,
|
||||||
|
|||||||
0
poulpy-hal/src/reference/vec_znx/convolution.rs
Normal file
0
poulpy-hal/src/reference/vec_znx/convolution.rs
Normal file
162
poulpy-hal/src/test_suite/convolution.rs
Normal file
162
poulpy-hal/src/test_suite/convolution.rs
Normal file
@@ -0,0 +1,162 @@
|
|||||||
|
use crate::{
|
||||||
|
api::{
|
||||||
|
Convolution, ModuleN, ScratchOwnedAlloc, ScratchOwnedBorrow, ScratchTakeBasic, TakeSlice, VecZnxBigNormalize,
|
||||||
|
VecZnxDftAlloc, VecZnxDftApply, VecZnxIdftApplyConsume, VecZnxNormalizeInplace,
|
||||||
|
},
|
||||||
|
layouts::{
|
||||||
|
Backend, FillUniform, Scratch, ScratchOwned, VecZnx, VecZnxBig, VecZnxDft, VecZnxToMut, VecZnxToRef, ZnxInfos, ZnxView,
|
||||||
|
ZnxViewMut, ZnxZero,
|
||||||
|
},
|
||||||
|
source::Source,
|
||||||
|
};
|
||||||
|
|
||||||
|
pub fn test_convolution<M, BE: Backend>(module: &M)
|
||||||
|
where
|
||||||
|
M: ModuleN
|
||||||
|
+ Convolution<BE>
|
||||||
|
+ VecZnxDftAlloc<BE>
|
||||||
|
+ VecZnxDftApply<BE>
|
||||||
|
+ VecZnxIdftApplyConsume<BE>
|
||||||
|
+ VecZnxBigNormalize<BE>
|
||||||
|
+ VecZnxNormalizeInplace<BE>,
|
||||||
|
Scratch<BE>: ScratchTakeBasic,
|
||||||
|
ScratchOwned<BE>: ScratchOwnedAlloc<BE> + ScratchOwnedBorrow<BE>,
|
||||||
|
{
|
||||||
|
let mut source: Source = Source::new([0u8; 32]);
|
||||||
|
|
||||||
|
let base2k: usize = 12;
|
||||||
|
|
||||||
|
for a_cols in 1..3 {
|
||||||
|
for b_cols in 1..3 {
|
||||||
|
for a_size in 1..5 {
|
||||||
|
for b_size in 1..5 {
|
||||||
|
let mut a: VecZnx<Vec<u8>> = VecZnx::alloc(module.n(), a_cols, a_size);
|
||||||
|
let mut b: VecZnx<Vec<u8>> = VecZnx::alloc(module.n(), b_cols, b_size);
|
||||||
|
|
||||||
|
let mut c_want: VecZnx<Vec<u8>> = VecZnx::alloc(module.n(), a_cols + b_cols - 1, b_size + a_size);
|
||||||
|
let mut c_have: VecZnx<Vec<u8>> = VecZnx::alloc(module.n(), c_want.cols(), c_want.size());
|
||||||
|
|
||||||
|
let mut scratch: ScratchOwned<BE> = ScratchOwned::alloc(module.convolution_tmp_bytes(c_want.size()));
|
||||||
|
|
||||||
|
a.fill_uniform(base2k, &mut source);
|
||||||
|
b.fill_uniform(base2k, &mut source);
|
||||||
|
|
||||||
|
let mut b_dft: VecZnxDft<Vec<u8>, BE> = module.vec_znx_dft_alloc(b.cols(), b.size());
|
||||||
|
|
||||||
|
for i in 0..b.cols() {
|
||||||
|
module.vec_znx_dft_apply(1, 0, &mut b_dft, i, &b, i);
|
||||||
|
}
|
||||||
|
|
||||||
|
for mut res_scale in 0..2 * c_want.size() as i64 + 1 {
|
||||||
|
res_scale -= c_want.size() as i64;
|
||||||
|
|
||||||
|
let mut c_have_dft: VecZnxDft<Vec<u8>, BE> = module.vec_znx_dft_alloc(c_have.cols(), c_have.size());
|
||||||
|
module.convolution(&mut c_have_dft, res_scale, &a, &b_dft, scratch.borrow());
|
||||||
|
|
||||||
|
let c_have_big: VecZnxBig<Vec<u8>, BE> = module.vec_znx_idft_apply_consume(c_have_dft);
|
||||||
|
|
||||||
|
for i in 0..c_have.cols() {
|
||||||
|
module.vec_znx_big_normalize(
|
||||||
|
base2k,
|
||||||
|
&mut c_have,
|
||||||
|
i,
|
||||||
|
base2k,
|
||||||
|
&c_have_big,
|
||||||
|
i,
|
||||||
|
scratch.borrow(),
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
convolution_naive(
|
||||||
|
module,
|
||||||
|
base2k,
|
||||||
|
&mut c_want,
|
||||||
|
res_scale,
|
||||||
|
&a,
|
||||||
|
&b,
|
||||||
|
scratch.borrow(),
|
||||||
|
);
|
||||||
|
|
||||||
|
assert_eq!(c_want, c_have);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn convolution_naive<R, A, B, M, BE: Backend>(
|
||||||
|
module: &M,
|
||||||
|
base2k: usize,
|
||||||
|
res: &mut R,
|
||||||
|
res_scale: i64,
|
||||||
|
a: &A,
|
||||||
|
b: &B,
|
||||||
|
scratch: &mut Scratch<BE>,
|
||||||
|
) where
|
||||||
|
R: VecZnxToMut,
|
||||||
|
A: VecZnxToRef,
|
||||||
|
B: VecZnxToRef,
|
||||||
|
M: VecZnxNormalizeInplace<BE>,
|
||||||
|
Scratch<BE>: TakeSlice,
|
||||||
|
{
|
||||||
|
let res: &mut VecZnx<&mut [u8]> = &mut res.to_mut();
|
||||||
|
let a: &VecZnx<&[u8]> = &a.to_ref();
|
||||||
|
let b: &VecZnx<&[u8]> = &b.to_ref();
|
||||||
|
|
||||||
|
assert!(res.cols() >= a.cols() + b.cols() - 1);
|
||||||
|
|
||||||
|
res.zero();
|
||||||
|
|
||||||
|
for a_col in 0..a.cols() {
|
||||||
|
for a_limb in 0..a.size() {
|
||||||
|
for b_col in 0..b.cols() {
|
||||||
|
for b_limb in 0..b.size() {
|
||||||
|
let res_scale_abs = res_scale.unsigned_abs() as usize;
|
||||||
|
|
||||||
|
let mut res_limb: usize = a_limb + b_limb + 1;
|
||||||
|
|
||||||
|
if res_scale <= 0 {
|
||||||
|
res_limb += res_scale_abs;
|
||||||
|
|
||||||
|
if res_limb < res.size() {
|
||||||
|
negacyclic_convolution_naive_add(
|
||||||
|
res.at_mut(a_col + b_col, res_limb),
|
||||||
|
a.at(a_col, a_limb),
|
||||||
|
b.at(b_col, b_limb),
|
||||||
|
);
|
||||||
|
}
|
||||||
|
} else if res_limb >= res_scale_abs {
|
||||||
|
res_limb -= res_scale_abs;
|
||||||
|
|
||||||
|
if res_limb < res.size() {
|
||||||
|
negacyclic_convolution_naive_add(
|
||||||
|
res.at_mut(a_col + b_col, res_limb),
|
||||||
|
a.at(a_col, a_limb),
|
||||||
|
b.at(b_col, b_limb),
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for i in 0..res.cols() {
|
||||||
|
module.vec_znx_normalize_inplace(base2k, res, i, scratch);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Adds the negacyclic product of `a` and `b` onto `res`, i.e. `res += a * b` modulo
/// `X^n + 1` with `n = res.len()`.
///
/// Only the first `n` coefficients of `a` and `b` are read; panics if either is shorter.
fn negacyclic_convolution_naive_add(res: &mut [i64], a: &[i64], b: &[i64]) {
    let degree: usize = res.len();
    for i in 0..degree {
        let lhs: i64 = a[i];
        for j in 0..degree {
            let term: i64 = lhs * b[j];
            let k: usize = i + j;
            if k < degree {
                res[k] += term;
            } else {
                // X^n = -1: the wrapped-around term changes sign.
                res[k - degree] -= term;
            }
        }
    }
}
|
||||||
@@ -1,3 +1,4 @@
|
|||||||
|
pub mod convolution;
|
||||||
pub mod serialization;
|
pub mod serialization;
|
||||||
pub mod svp;
|
pub mod svp;
|
||||||
pub mod vec_znx;
|
pub mod vec_znx;
|
||||||
|
|||||||
Reference in New Issue
Block a user