mirror of
https://github.com/arnaucube/poulpy.git
synced 2026-02-10 05:06:44 +01:00
Add glwe tensoiring
This commit is contained in:
@@ -6,39 +6,19 @@ use crate::{
|
||||
layouts::{Backend, Module, Scratch, VecZnxDftToMut, VecZnxDftToRef, VecZnxToRef, ZnxInfos},
|
||||
};
|
||||
|
||||
impl<BE: Backend> Convolution<BE> for Module<BE>
|
||||
impl<BE: Backend> BivariateTensoring<BE> for Module<BE>
|
||||
where
|
||||
Self: Sized
|
||||
+ ModuleN
|
||||
+ SvpPPolAlloc<BE>
|
||||
+ SvpApplyDftToDft<BE>
|
||||
+ SvpPrepare<BE>
|
||||
+ SvpPPolBytesOf
|
||||
+ VecZnxDftBytesOf
|
||||
+ VecZnxDftAddScaledInplace<BE>
|
||||
+ VecZnxDftZero<BE>,
|
||||
Self: BivariateConvolution<BE>,
|
||||
Scratch<BE>: ScratchTakeBasic,
|
||||
{
|
||||
}
|
||||
|
||||
pub trait Convolution<BE: Backend>
|
||||
pub trait BivariateTensoring<BE: Backend>
|
||||
where
|
||||
Self: Sized
|
||||
+ ModuleN
|
||||
+ SvpPPolAlloc<BE>
|
||||
+ SvpApplyDftToDft<BE>
|
||||
+ SvpPrepare<BE>
|
||||
+ SvpPPolBytesOf
|
||||
+ VecZnxDftBytesOf
|
||||
+ VecZnxDftAddScaledInplace<BE>
|
||||
+ VecZnxDftZero<BE>,
|
||||
Self: BivariateConvolution<BE>,
|
||||
Scratch<BE>: ScratchTakeBasic,
|
||||
{
|
||||
fn convolution_tmp_bytes(&self, b_size: usize) -> usize {
|
||||
self.bytes_of_svp_ppol(1) + self.bytes_of_vec_znx_dft(1, b_size)
|
||||
}
|
||||
|
||||
fn bivariate_convolution_full<R, A, B>(&self, k: i64, res: &mut R, a: &A, b: &B, scratch: &mut Scratch<BE>)
|
||||
fn bivariate_tensoring<R, A, B>(&self, k: i64, res: &mut R, a: &A, b: &B, scratch: &mut Scratch<BE>)
|
||||
where
|
||||
R: VecZnxDftToMut<BE>,
|
||||
A: VecZnxToRef,
|
||||
@@ -55,14 +35,48 @@ where
|
||||
assert!(res_cols >= a_cols + b_cols - 1);
|
||||
|
||||
for res_col in 0..res_cols {
|
||||
let a_min: usize = res_col.saturating_sub(b_cols - 1);
|
||||
let a_max: usize = res_col.min(a_cols - 1);
|
||||
self.bivariate_convolution_single(k, res, res_col, a, a_min, b, res_col - a_min, scratch);
|
||||
for a_col in a_min + 1..a_max + 1 {
|
||||
self.bivariate_convolution_single_add(k, res, res_col, a, a_col, b, res_col - a_col, scratch);
|
||||
self.vec_znx_dft_zero(res, res_col);
|
||||
}
|
||||
|
||||
for a_col in 0..a_cols {
|
||||
for b_col in 0..b_cols {
|
||||
self.bivariate_convolution_add(k, res, a_col + b_col, a, a_col, b, b_col, scratch);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<BE: Backend> BivariateConvolution<BE> for Module<BE>
|
||||
where
|
||||
Self: Sized
|
||||
+ ModuleN
|
||||
+ SvpPPolAlloc<BE>
|
||||
+ SvpApplyDftToDft<BE>
|
||||
+ SvpPrepare<BE>
|
||||
+ SvpPPolBytesOf
|
||||
+ VecZnxDftBytesOf
|
||||
+ VecZnxDftAddScaledInplace<BE>
|
||||
+ VecZnxDftZero<BE>,
|
||||
Scratch<BE>: ScratchTakeBasic,
|
||||
{
|
||||
}
|
||||
|
||||
pub trait BivariateConvolution<BE: Backend>
|
||||
where
|
||||
Self: Sized
|
||||
+ ModuleN
|
||||
+ SvpPPolAlloc<BE>
|
||||
+ SvpApplyDftToDft<BE>
|
||||
+ SvpPrepare<BE>
|
||||
+ SvpPPolBytesOf
|
||||
+ VecZnxDftBytesOf
|
||||
+ VecZnxDftAddScaledInplace<BE>
|
||||
+ VecZnxDftZero<BE>,
|
||||
Scratch<BE>: ScratchTakeBasic,
|
||||
{
|
||||
fn convolution_tmp_bytes(&self, b_size: usize) -> usize {
|
||||
self.bytes_of_svp_ppol(1) + self.bytes_of_vec_znx_dft(1, b_size)
|
||||
}
|
||||
|
||||
/// Evaluates a bivariate convolution over Z[X, Y] / (X^N + 1) where Y = 2^-K over the
|
||||
/// selected columsn and stores the result on the selected column, scaled by 2^{k * Base2K}
|
||||
@@ -96,7 +110,7 @@ where
|
||||
/// [r03, r13, r23, r33]
|
||||
///
|
||||
/// If res.size() < a.size() + b.size() + 1 + k, result is truncated accordingly in the Y dimension.
|
||||
fn bivariate_convolution_single_add<R, A, B>(
|
||||
fn bivariate_convolution_add<R, A, B>(
|
||||
&self,
|
||||
k: i64,
|
||||
res: &mut R,
|
||||
@@ -123,10 +137,9 @@ where
|
||||
self.svp_apply_dft_to_dft(&mut res_tmp, 0, &ppol, 0, b, b_col);
|
||||
self.vec_znx_dft_add_scaled_inplace(res, res_col, &res_tmp, 0, -(1 + a_limb as i64) + k);
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
fn bivariate_convolution_single<R, A, B>(
|
||||
fn bivariate_convolution<R, A, B>(
|
||||
&self,
|
||||
k: i64,
|
||||
res: &mut R,
|
||||
@@ -142,6 +155,6 @@ where
|
||||
B: VecZnxDftToRef<BE>,
|
||||
{
|
||||
self.vec_znx_dft_zero(res, res_col);
|
||||
self.bivariate_convolution_single_add(k, res, res_col, a, a_col, b, b_col, scratch);
|
||||
self.bivariate_convolution_add(k, res, res_col, a, a_col, b, b_col, scratch);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
use crate::{
|
||||
api::{
|
||||
Convolution, ModuleN, ScratchOwnedAlloc, ScratchOwnedBorrow, ScratchTakeBasic, TakeSlice, VecZnxBigAlloc,
|
||||
BivariateTensoring, ModuleN, ScratchOwnedAlloc, ScratchOwnedBorrow, ScratchTakeBasic, TakeSlice, VecZnxBigAlloc,
|
||||
VecZnxBigNormalize, VecZnxDftAlloc, VecZnxDftApply, VecZnxIdftApplyTmpA, VecZnxNormalizeInplace,
|
||||
},
|
||||
layouts::{
|
||||
@@ -10,10 +10,10 @@ use crate::{
|
||||
source::Source,
|
||||
};
|
||||
|
||||
pub fn test_convolution<M, BE: Backend>(module: &M)
|
||||
pub fn test_bivariate_tensoring<M, BE: Backend>(module: &M)
|
||||
where
|
||||
M: ModuleN
|
||||
+ Convolution<BE>
|
||||
+ BivariateTensoring<BE>
|
||||
+ VecZnxDftAlloc<BE>
|
||||
+ VecZnxDftApply<BE>
|
||||
+ VecZnxIdftApplyTmpA<BE>
|
||||
@@ -55,7 +55,7 @@ where
|
||||
for mut k in 0..(2 * c_size + 1) as i64 {
|
||||
k -= c_size as i64;
|
||||
|
||||
module.bivariate_convolution_full(k, &mut c_have_dft, &a, &b_dft, scratch.borrow());
|
||||
module.bivariate_tensoring(k, &mut c_have_dft, &a, &b_dft, scratch.borrow());
|
||||
|
||||
for i in 0..c_cols {
|
||||
module.vec_znx_idft_apply_tmpa(&mut c_have_big, i, &mut c_have_dft, i);
|
||||
@@ -73,13 +73,13 @@ where
|
||||
);
|
||||
}
|
||||
|
||||
convolution_naive(module, base2k, k, &mut c_want, &a, &b, scratch.borrow());
|
||||
bivariate_tensoring_naive(module, base2k, k, &mut c_want, &a, &b, scratch.borrow());
|
||||
|
||||
assert_eq!(c_want, c_have);
|
||||
}
|
||||
}
|
||||
|
||||
fn convolution_naive<R, A, B, M, BE: Backend>(
|
||||
fn bivariate_tensoring_naive<R, A, B, M, BE: Backend>(
|
||||
module: &M,
|
||||
base2k: usize,
|
||||
k: i64,
|
||||
|
||||
Reference in New Issue
Block a user