This commit is contained in:
Jean-Philippe Bossuat
2025-10-26 16:32:22 +01:00
parent 98208d5e67
commit 881483d1bb
11 changed files with 173 additions and 140 deletions

View File

@@ -1,7 +1,7 @@
use crate::{
api::{
Convolution, ModuleN, ScratchOwnedAlloc, ScratchOwnedBorrow, ScratchTakeBasic, TakeSlice, VecZnxBigNormalize,
VecZnxDftAlloc, VecZnxDftApply, VecZnxIdftApplyConsume, VecZnxNormalizeInplace,
Convolution, ModuleN, ScratchOwnedAlloc, ScratchOwnedBorrow, ScratchTakeBasic, TakeSlice, VecZnxBigAlloc,
VecZnxBigNormalize, VecZnxDftAlloc, VecZnxDftApply, VecZnxIdftApplyTmpA, VecZnxNormalizeInplace,
},
layouts::{
Backend, FillUniform, Scratch, ScratchOwned, VecZnx, VecZnxBig, VecZnxDft, VecZnxToMut, VecZnxToRef, ZnxInfos, ZnxView,
@@ -16,9 +16,10 @@ where
+ Convolution<BE>
+ VecZnxDftAlloc<BE>
+ VecZnxDftApply<BE>
+ VecZnxIdftApplyConsume<BE>
+ VecZnxIdftApplyTmpA<BE>
+ VecZnxBigNormalize<BE>
+ VecZnxNormalizeInplace<BE>,
+ VecZnxNormalizeInplace<BE>
+ VecZnxBigAlloc<BE>,
Scratch<BE>: ScratchTakeBasic,
ScratchOwned<BE>: ScratchOwnedAlloc<BE> + ScratchOwnedBorrow<BE>,
{
@@ -26,70 +27,63 @@ where
let base2k: usize = 12;
for a_cols in 1..3 {
for b_cols in 1..3 {
for a_size in 1..5 {
for b_size in 1..5 {
let mut a: VecZnx<Vec<u8>> = VecZnx::alloc(module.n(), a_cols, a_size);
let mut b: VecZnx<Vec<u8>> = VecZnx::alloc(module.n(), b_cols, b_size);
let a_cols: usize = 3;
let b_cols: usize = 3;
let a_size: usize = 3;
let b_size: usize = 3;
let c_cols: usize = a_cols + b_cols - 1;
let c_size: usize = a_size + b_size;
let mut c_want: VecZnx<Vec<u8>> = VecZnx::alloc(module.n(), a_cols + b_cols - 1, b_size + a_size);
let mut c_have: VecZnx<Vec<u8>> = VecZnx::alloc(module.n(), c_want.cols(), c_want.size());
let mut a: VecZnx<Vec<u8>> = VecZnx::alloc(module.n(), a_cols, a_size);
let mut b: VecZnx<Vec<u8>> = VecZnx::alloc(module.n(), b_cols, b_size);
let mut scratch: ScratchOwned<BE> = ScratchOwned::alloc(module.convolution_tmp_bytes(c_want.size()));
let mut c_want: VecZnx<Vec<u8>> = VecZnx::alloc(module.n(), c_cols, c_size);
let mut c_have: VecZnx<Vec<u8>> = VecZnx::alloc(module.n(), c_cols, c_size);
let mut c_have_dft: VecZnxDft<Vec<u8>, BE> = module.vec_znx_dft_alloc(c_cols, c_size);
let mut c_have_big: VecZnxBig<Vec<u8>, BE> = module.vec_znx_big_alloc(c_cols, c_size);
a.fill_uniform(base2k, &mut source);
b.fill_uniform(base2k, &mut source);
let mut scratch: ScratchOwned<BE> = ScratchOwned::alloc(module.convolution_tmp_bytes(b_size));
let mut b_dft: VecZnxDft<Vec<u8>, BE> = module.vec_znx_dft_alloc(b.cols(), b.size());
a.fill_uniform(base2k, &mut source);
b.fill_uniform(base2k, &mut source);
for i in 0..b.cols() {
module.vec_znx_dft_apply(1, 0, &mut b_dft, i, &b, i);
}
let mut b_dft: VecZnxDft<Vec<u8>, BE> = module.vec_znx_dft_alloc(b_cols, b_size);
for i in 0..b.cols() {
module.vec_znx_dft_apply(1, 0, &mut b_dft, i, &b, i);
}
for mut res_scale in 0..2 * c_want.size() as i64 + 1 {
res_scale -= c_want.size() as i64;
for mut k in 0..(2 * c_size + 1) as i64 {
k -= c_size as i64;
let mut c_have_dft: VecZnxDft<Vec<u8>, BE> = module.vec_znx_dft_alloc(c_have.cols(), c_have.size());
module.convolution(&mut c_have_dft, res_scale, &a, &b_dft, scratch.borrow());
module.bivariate_convolution_full(k, &mut c_have_dft, &a, &b_dft, scratch.borrow());
let c_have_big: VecZnxBig<Vec<u8>, BE> = module.vec_znx_idft_apply_consume(c_have_dft);
for i in 0..c_have.cols() {
module.vec_znx_big_normalize(
base2k,
&mut c_have,
i,
base2k,
&c_have_big,
i,
scratch.borrow(),
);
}
convolution_naive(
module,
base2k,
&mut c_want,
res_scale,
&a,
&b,
scratch.borrow(),
);
assert_eq!(c_want, c_have);
}
}
}
for i in 0..c_cols {
module.vec_znx_idft_apply_tmpa(&mut c_have_big, i, &mut c_have_dft, i);
}
for i in 0..c_cols {
module.vec_znx_big_normalize(
base2k,
&mut c_have,
i,
base2k,
&c_have_big,
i,
scratch.borrow(),
);
}
convolution_naive(module, base2k, k, &mut c_want, &a, &b, scratch.borrow());
assert_eq!(c_want, c_have);
}
}
fn convolution_naive<R, A, B, M, BE: Backend>(
module: &M,
base2k: usize,
k: i64,
res: &mut R,
res_scale: i64,
a: &A,
b: &B,
scratch: &mut Scratch<BE>,
@@ -112,11 +106,11 @@ fn convolution_naive<R, A, B, M, BE: Backend>(
for a_limb in 0..a.size() {
for b_col in 0..b.cols() {
for b_limb in 0..b.size() {
let res_scale_abs = res_scale.unsigned_abs() as usize;
let res_scale_abs = k.unsigned_abs() as usize;
let mut res_limb: usize = a_limb + b_limb + 1;
if res_scale <= 0 {
if k <= 0 {
res_limb += res_scale_abs;
if res_limb < res.size() {