Add glwe tensoiring

This commit is contained in:
Pro7ech
2025-10-26 19:03:15 +01:00
parent 6e9cef5ecd
commit 41ca5aafcc
9 changed files with 199 additions and 138 deletions

View File

@@ -6,39 +6,19 @@ use crate::{
layouts::{Backend, Module, Scratch, VecZnxDftToMut, VecZnxDftToRef, VecZnxToRef, ZnxInfos},
};
impl<BE: Backend> Convolution<BE> for Module<BE>
impl<BE: Backend> BivariateTensoring<BE> for Module<BE>
where
Self: Sized
+ ModuleN
+ SvpPPolAlloc<BE>
+ SvpApplyDftToDft<BE>
+ SvpPrepare<BE>
+ SvpPPolBytesOf
+ VecZnxDftBytesOf
+ VecZnxDftAddScaledInplace<BE>
+ VecZnxDftZero<BE>,
Self: BivariateConvolution<BE>,
Scratch<BE>: ScratchTakeBasic,
{
}
pub trait Convolution<BE: Backend>
pub trait BivariateTensoring<BE: Backend>
where
Self: Sized
+ ModuleN
+ SvpPPolAlloc<BE>
+ SvpApplyDftToDft<BE>
+ SvpPrepare<BE>
+ SvpPPolBytesOf
+ VecZnxDftBytesOf
+ VecZnxDftAddScaledInplace<BE>
+ VecZnxDftZero<BE>,
Self: BivariateConvolution<BE>,
Scratch<BE>: ScratchTakeBasic,
{
fn convolution_tmp_bytes(&self, b_size: usize) -> usize {
self.bytes_of_svp_ppol(1) + self.bytes_of_vec_znx_dft(1, b_size)
}
fn bivariate_convolution_full<R, A, B>(&self, k: i64, res: &mut R, a: &A, b: &B, scratch: &mut Scratch<BE>)
fn bivariate_tensoring<R, A, B>(&self, k: i64, res: &mut R, a: &A, b: &B, scratch: &mut Scratch<BE>)
where
R: VecZnxDftToMut<BE>,
A: VecZnxToRef,
@@ -55,14 +35,48 @@ where
assert!(res_cols >= a_cols + b_cols - 1);
for res_col in 0..res_cols {
let a_min: usize = res_col.saturating_sub(b_cols - 1);
let a_max: usize = res_col.min(a_cols - 1);
self.bivariate_convolution_single(k, res, res_col, a, a_min, b, res_col - a_min, scratch);
for a_col in a_min + 1..a_max + 1 {
self.bivariate_convolution_single_add(k, res, res_col, a, a_col, b, res_col - a_col, scratch);
self.vec_znx_dft_zero(res, res_col);
}
for a_col in 0..a_cols {
for b_col in 0..b_cols {
self.bivariate_convolution_add(k, res, a_col + b_col, a, a_col, b, b_col, scratch);
}
}
}
}
impl<BE: Backend> BivariateConvolution<BE> for Module<BE>
where
Self: Sized
+ ModuleN
+ SvpPPolAlloc<BE>
+ SvpApplyDftToDft<BE>
+ SvpPrepare<BE>
+ SvpPPolBytesOf
+ VecZnxDftBytesOf
+ VecZnxDftAddScaledInplace<BE>
+ VecZnxDftZero<BE>,
Scratch<BE>: ScratchTakeBasic,
{
}
pub trait BivariateConvolution<BE: Backend>
where
Self: Sized
+ ModuleN
+ SvpPPolAlloc<BE>
+ SvpApplyDftToDft<BE>
+ SvpPrepare<BE>
+ SvpPPolBytesOf
+ VecZnxDftBytesOf
+ VecZnxDftAddScaledInplace<BE>
+ VecZnxDftZero<BE>,
Scratch<BE>: ScratchTakeBasic,
{
fn convolution_tmp_bytes(&self, b_size: usize) -> usize {
self.bytes_of_svp_ppol(1) + self.bytes_of_vec_znx_dft(1, b_size)
}
/// Evaluates a bivariate convolution over Z[X, Y] / (X^N + 1) where Y = 2^-K over the
/// selected columsn and stores the result on the selected column, scaled by 2^{k * Base2K}
@@ -96,7 +110,7 @@ where
/// [r03, r13, r23, r33]
///
/// If res.size() < a.size() + b.size() + 1 + k, result is truncated accordingly in the Y dimension.
fn bivariate_convolution_single_add<R, A, B>(
fn bivariate_convolution_add<R, A, B>(
&self,
k: i64,
res: &mut R,
@@ -123,10 +137,9 @@ where
self.svp_apply_dft_to_dft(&mut res_tmp, 0, &ppol, 0, b, b_col);
self.vec_znx_dft_add_scaled_inplace(res, res_col, &res_tmp, 0, -(1 + a_limb as i64) + k);
}
}
fn bivariate_convolution_single<R, A, B>(
fn bivariate_convolution<R, A, B>(
&self,
k: i64,
res: &mut R,
@@ -142,6 +155,6 @@ where
B: VecZnxDftToRef<BE>,
{
self.vec_znx_dft_zero(res, res_col);
self.bivariate_convolution_single_add(k, res, res_col, a, a_col, b, b_col, scratch);
self.bivariate_convolution_add(k, res, res_col, a, a_col, b, b_col, scratch);
}
}

View File

@@ -1,6 +1,6 @@
use crate::{
api::{
Convolution, ModuleN, ScratchOwnedAlloc, ScratchOwnedBorrow, ScratchTakeBasic, TakeSlice, VecZnxBigAlloc,
BivariateTensoring, ModuleN, ScratchOwnedAlloc, ScratchOwnedBorrow, ScratchTakeBasic, TakeSlice, VecZnxBigAlloc,
VecZnxBigNormalize, VecZnxDftAlloc, VecZnxDftApply, VecZnxIdftApplyTmpA, VecZnxNormalizeInplace,
},
layouts::{
@@ -10,10 +10,10 @@ use crate::{
source::Source,
};
pub fn test_convolution<M, BE: Backend>(module: &M)
pub fn test_bivariate_tensoring<M, BE: Backend>(module: &M)
where
M: ModuleN
+ Convolution<BE>
+ BivariateTensoring<BE>
+ VecZnxDftAlloc<BE>
+ VecZnxDftApply<BE>
+ VecZnxIdftApplyTmpA<BE>
@@ -55,7 +55,7 @@ where
for mut k in 0..(2 * c_size + 1) as i64 {
k -= c_size as i64;
module.bivariate_convolution_full(k, &mut c_have_dft, &a, &b_dft, scratch.borrow());
module.bivariate_tensoring(k, &mut c_have_dft, &a, &b_dft, scratch.borrow());
for i in 0..c_cols {
module.vec_znx_idft_apply_tmpa(&mut c_have_big, i, &mut c_have_dft, i);
@@ -73,13 +73,13 @@ where
);
}
convolution_naive(module, base2k, k, &mut c_want, &a, &b, scratch.borrow());
bivariate_tensoring_naive(module, base2k, k, &mut c_want, &a, &b, scratch.borrow());
assert_eq!(c_want, c_have);
}
}
fn convolution_naive<R, A, B, M, BE: Backend>(
fn bivariate_tensoring_naive<R, A, B, M, BE: Backend>(
module: &M,
base2k: usize,
k: i64,