Add cross-basek normalization (#90)

* added cross_basek_normalization

* updated method signatures to take layouts

* fixed cross-base normalization

fix #91
fix #93
This commit is contained in:
Jean-Philippe Bossuat
2025-09-30 14:40:10 +02:00
committed by GitHub
parent 4da790ea6a
commit 37e13b965c
216 changed files with 12481 additions and 7745 deletions

View File

@@ -15,16 +15,16 @@ use crate::{
pub fn bench_svp_prepare<B>(c: &mut Criterion, label: &str)
where
Module<B>: SvpPrepare<B> + SvpPPolAlloc<B> + ModuleNew<B>,
B: Backend<ScalarPrep = f64>,
B: Backend,
{
let group_name: String = format!("svp_prepare::{}", label);
let group_name: String = format!("svp_prepare::{label}");
let mut group = c.benchmark_group(group_name);
fn runner<B>(log_n: usize) -> impl FnMut()
where
Module<B>: SvpPrepare<B> + SvpPPolAlloc<B> + ModuleNew<B>,
B: Backend<ScalarPrep = f64>,
B: Backend,
{
let module: Module<B> = Module::<B>::new(1 << log_n);
@@ -53,16 +53,16 @@ where
pub fn bench_svp_apply_dft<B>(c: &mut Criterion, label: &str)
where
Module<B>: SvpApplyDft<B> + SvpPPolAlloc<B> + ModuleNew<B> + VecZnxDftAlloc<B>,
B: Backend<ScalarPrep = f64>,
B: Backend,
{
let group_name: String = format!("svp_apply_dft::{}", label);
let group_name: String = format!("svp_apply_dft::{label}");
let mut group = c.benchmark_group(group_name);
fn runner<B>(params: [usize; 3]) -> impl FnMut()
where
Module<B>: SvpApplyDft<B> + SvpPPolAlloc<B> + ModuleNew<B> + VecZnxDftAlloc<B>,
B: Backend<ScalarPrep = f64>,
B: Backend,
{
let n: usize = 1 << params[0];
let cols: usize = params[1];
@@ -100,16 +100,16 @@ where
pub fn bench_svp_apply_dft_to_dft<B>(c: &mut Criterion, label: &str)
where
Module<B>: SvpApplyDftToDft<B> + SvpPPolAlloc<B> + ModuleNew<B> + VecZnxDftAlloc<B>,
B: Backend<ScalarPrep = f64>,
B: Backend,
{
let group_name: String = format!("svp_apply_dft_to_dft::{}", label);
let group_name: String = format!("svp_apply_dft_to_dft::{label}");
let mut group = c.benchmark_group(group_name);
fn runner<B>(params: [usize; 3]) -> impl FnMut()
where
Module<B>: SvpApplyDftToDft<B> + SvpPPolAlloc<B> + ModuleNew<B> + VecZnxDftAlloc<B>,
B: Backend<ScalarPrep = f64>,
B: Backend,
{
let n: usize = 1 << params[0];
let cols: usize = params[1];
@@ -147,16 +147,16 @@ where
pub fn bench_svp_apply_dft_to_dft_add<B>(c: &mut Criterion, label: &str)
where
Module<B>: SvpApplyDftToDftAdd<B> + SvpPPolAlloc<B> + ModuleNew<B> + VecZnxDftAlloc<B>,
B: Backend<ScalarPrep = f64>,
B: Backend,
{
let group_name: String = format!("svp_apply_dft_to_dft_add::{}", label);
let group_name: String = format!("svp_apply_dft_to_dft_add::{label}");
let mut group = c.benchmark_group(group_name);
fn runner<B>(params: [usize; 3]) -> impl FnMut()
where
Module<B>: SvpApplyDftToDftAdd<B> + SvpPPolAlloc<B> + ModuleNew<B> + VecZnxDftAlloc<B>,
B: Backend<ScalarPrep = f64>,
B: Backend,
{
let n: usize = 1 << params[0];
let cols: usize = params[1];
@@ -194,16 +194,16 @@ where
pub fn bench_svp_apply_dft_to_dft_inplace<B>(c: &mut Criterion, label: &str)
where
Module<B>: SvpApplyDftToDftInplace<B> + SvpPPolAlloc<B> + ModuleNew<B> + VecZnxDftAlloc<B>,
B: Backend<ScalarPrep = f64>,
B: Backend,
{
let group_name: String = format!("svp_apply_dft_to_dft_inplace::{}", label);
let group_name: String = format!("svp_apply_dft_to_dft_inplace::{label}");
let mut group = c.benchmark_group(group_name);
fn runner<B>(params: [usize; 3]) -> impl FnMut()
where
Module<B>: SvpApplyDftToDftInplace<B> + SvpPPolAlloc<B> + ModuleNew<B> + VecZnxDftAlloc<B>,
B: Backend<ScalarPrep = f64>,
B: Backend,
{
let n: usize = 1 << params[0];
let cols: usize = params[1];

View File

@@ -8,7 +8,7 @@ use crate::{
ModuleNew, ScratchOwnedAlloc, ScratchOwnedBorrow, VecZnxBigAdd, VecZnxBigAddInplace, VecZnxBigAddSmall,
VecZnxBigAddSmallInplace, VecZnxBigAlloc, VecZnxBigAutomorphism, VecZnxBigAutomorphismInplace,
VecZnxBigAutomorphismInplaceTmpBytes, VecZnxBigNegate, VecZnxBigNegateInplace, VecZnxBigNormalize,
VecZnxBigNormalizeTmpBytes, VecZnxBigSub, VecZnxBigSubABInplace, VecZnxBigSubBAInplace, VecZnxBigSubSmallA,
VecZnxBigNormalizeTmpBytes, VecZnxBigSub, VecZnxBigSubInplace, VecZnxBigSubNegateInplace, VecZnxBigSubSmallA,
VecZnxBigSubSmallB,
},
layouts::{Backend, DataViewMut, Module, ScratchOwned, VecZnx, VecZnxBig},
@@ -19,7 +19,7 @@ pub fn bench_vec_znx_big_add<B: Backend>(c: &mut Criterion, label: &str)
where
Module<B>: VecZnxBigAdd<B> + ModuleNew<B> + VecZnxBigAlloc<B>,
{
let group_name: String = format!("vec_znx_big_add::{}", label);
let group_name: String = format!("vec_znx_big_add::{label}");
let mut group = c.benchmark_group(group_name);
@@ -65,7 +65,7 @@ pub fn bench_vec_znx_big_add_inplace<B: Backend>(c: &mut Criterion, label: &str)
where
Module<B>: VecZnxBigAddInplace<B> + ModuleNew<B> + VecZnxBigAlloc<B>,
{
let group_name: String = format!("vec_znx_big_add_inplace::{}", label);
let group_name: String = format!("vec_znx_big_add_inplace::{label}");
let mut group = c.benchmark_group(group_name);
@@ -109,7 +109,7 @@ pub fn bench_vec_znx_big_add_small<B: Backend>(c: &mut Criterion, label: &str)
where
Module<B>: VecZnxBigAddSmall<B> + ModuleNew<B> + VecZnxBigAlloc<B>,
{
let group_name: String = format!("vec_znx_big_add_small::{}", label);
let group_name: String = format!("vec_znx_big_add_small::{label}");
let mut group = c.benchmark_group(group_name);
@@ -155,7 +155,7 @@ pub fn bench_vec_znx_big_add_small_inplace<B: Backend>(c: &mut Criterion, label:
where
Module<B>: VecZnxBigAddSmallInplace<B> + ModuleNew<B> + VecZnxBigAlloc<B>,
{
let group_name: String = format!("vec_znx_big_add_small_inplace::{}", label);
let group_name: String = format!("vec_znx_big_add_small_inplace::{label}");
let mut group = c.benchmark_group(group_name);
@@ -199,7 +199,7 @@ pub fn bench_vec_znx_big_automorphism<B: Backend>(c: &mut Criterion, label: &str
where
Module<B>: VecZnxBigAutomorphism<B> + ModuleNew<B> + VecZnxBigAlloc<B>,
{
let group_name: String = format!("vec_znx_big_automorphism::{}", label);
let group_name: String = format!("vec_znx_big_automorphism::{label}");
let mut group = c.benchmark_group(group_name);
@@ -244,7 +244,7 @@ where
Module<B>: VecZnxBigAutomorphismInplace<B> + VecZnxBigAutomorphismInplaceTmpBytes + ModuleNew<B> + VecZnxBigAlloc<B>,
ScratchOwned<B>: ScratchOwnedAlloc<B> + ScratchOwnedBorrow<B>,
{
let group_name: String = format!("vec_znx_automorphism_inplace::{}", label);
let group_name: String = format!("vec_znx_automorphism_inplace::{label}");
let mut group = c.benchmark_group(group_name);
@@ -289,7 +289,7 @@ pub fn bench_vec_znx_big_negate<B: Backend>(c: &mut Criterion, label: &str)
where
Module<B>: VecZnxBigNegate<B> + ModuleNew<B> + VecZnxBigAlloc<B>,
{
let group_name: String = format!("vec_znx_big_negate::{}", label);
let group_name: String = format!("vec_znx_big_negate::{label}");
let mut group = c.benchmark_group(group_name);
@@ -332,7 +332,7 @@ pub fn bench_vec_znx_big_negate_inplace<B: Backend>(c: &mut Criterion, label: &s
where
Module<B>: VecZnxBigNegateInplace<B> + ModuleNew<B> + VecZnxBigAlloc<B>,
{
let group_name: String = format!("vec_znx_negate_big_inplace::{}", label);
let group_name: String = format!("vec_znx_negate_big_inplace::{label}");
let mut group = c.benchmark_group(group_name);
@@ -374,7 +374,7 @@ where
Module<B>: VecZnxBigNormalize<B> + ModuleNew<B> + VecZnxBigNormalizeTmpBytes + VecZnxBigAlloc<B>,
ScratchOwned<B>: ScratchOwnedAlloc<B> + ScratchOwnedBorrow<B>,
{
let group_name: String = format!("vec_znx_big_normalize::{}", label);
let group_name: String = format!("vec_znx_big_normalize::{label}");
let mut group = c.benchmark_group(group_name);
@@ -389,7 +389,7 @@ where
let module: Module<B> = Module::<B>::new(n as u64);
let basek: usize = 50;
let base2k: usize = 50;
let mut source: Source = Source::new([0u8; 32]);
@@ -404,7 +404,7 @@ where
move || {
for i in 0..cols {
module.vec_znx_big_normalize(basek, &mut res, i, &a, i, scratch.borrow());
module.vec_znx_big_normalize(base2k, &mut res, i, base2k, &a, i, scratch.borrow());
}
black_box(());
}
@@ -423,7 +423,7 @@ pub fn bench_vec_znx_big_sub<B: Backend>(c: &mut Criterion, label: &str)
where
Module<B>: VecZnxBigSub<B> + ModuleNew<B> + VecZnxBigAlloc<B>,
{
let group_name: String = format!("vec_znx_big_sub::{}", label);
let group_name: String = format!("vec_znx_big_sub::{label}");
let mut group = c.benchmark_group(group_name);
@@ -464,17 +464,17 @@ where
group.finish();
}
pub fn bench_vec_znx_big_sub_ab_inplace<B: Backend>(c: &mut Criterion, label: &str)
pub fn bench_vec_znx_big_sub_inplace<B: Backend>(c: &mut Criterion, label: &str)
where
Module<B>: VecZnxBigSubABInplace<B> + ModuleNew<B> + VecZnxBigAlloc<B>,
Module<B>: VecZnxBigSubInplace<B> + ModuleNew<B> + VecZnxBigAlloc<B>,
{
let group_name: String = format!("vec_znx_big_sub_inplace::{}", label);
let group_name: String = format!("vec_znx_big_sub_inplace::{label}");
let mut group = c.benchmark_group(group_name);
fn runner<B: Backend>(params: [usize; 3]) -> impl FnMut()
where
Module<B>: VecZnxBigSubABInplace<B> + ModuleNew<B> + VecZnxBigAlloc<B>,
Module<B>: VecZnxBigSubInplace<B> + ModuleNew<B> + VecZnxBigAlloc<B>,
{
let module: Module<B> = Module::<B>::new(1 << params[0]);
@@ -492,7 +492,7 @@ where
move || {
for i in 0..cols {
module.vec_znx_big_sub_ab_inplace(&mut c, i, &a, i);
module.vec_znx_big_sub_inplace(&mut c, i, &a, i);
}
black_box(());
}
@@ -507,17 +507,17 @@ where
group.finish();
}
pub fn bench_vec_znx_big_sub_ba_inplace<B: Backend>(c: &mut Criterion, label: &str)
pub fn bench_vec_znx_big_sub_negate_inplace<B: Backend>(c: &mut Criterion, label: &str)
where
Module<B>: VecZnxBigSubBAInplace<B> + ModuleNew<B> + VecZnxBigAlloc<B>,
Module<B>: VecZnxBigSubNegateInplace<B> + ModuleNew<B> + VecZnxBigAlloc<B>,
{
let group_name: String = format!("vec_znx_big_sub_inplace::{}", label);
let group_name: String = format!("vec_znx_big_sub_inplace::{label}");
let mut group = c.benchmark_group(group_name);
fn runner<B: Backend>(params: [usize; 3]) -> impl FnMut()
where
Module<B>: VecZnxBigSubBAInplace<B> + ModuleNew<B> + VecZnxBigAlloc<B>,
Module<B>: VecZnxBigSubNegateInplace<B> + ModuleNew<B> + VecZnxBigAlloc<B>,
{
let module: Module<B> = Module::<B>::new(1 << params[0]);
@@ -535,7 +535,7 @@ where
move || {
for i in 0..cols {
module.vec_znx_big_sub_ba_inplace(&mut c, i, &a, i);
module.vec_znx_big_sub_negate_inplace(&mut c, i, &a, i);
}
black_box(());
}
@@ -554,7 +554,7 @@ pub fn bench_vec_znx_big_sub_small_a<B: Backend>(c: &mut Criterion, label: &str)
where
Module<B>: VecZnxBigSubSmallA<B> + ModuleNew<B> + VecZnxBigAlloc<B>,
{
let group_name: String = format!("vec_znx_big_sub_small_a::{}", label);
let group_name: String = format!("vec_znx_big_sub_small_a::{label}");
let mut group = c.benchmark_group(group_name);
@@ -599,7 +599,7 @@ pub fn bench_vec_znx_big_sub_small_b<B: Backend>(c: &mut Criterion, label: &str)
where
Module<B>: VecZnxBigSubSmallB<B> + ModuleNew<B> + VecZnxBigAlloc<B>,
{
let group_name: String = format!("vec_znx_big_sub_small_b::{}", label);
let group_name: String = format!("vec_znx_big_sub_small_b::{label}");
let mut group = c.benchmark_group(group_name);

View File

@@ -6,7 +6,7 @@ use rand::RngCore;
use crate::{
api::{
ModuleNew, ScratchOwnedAlloc, ScratchOwnedBorrow, VecZnxBigAlloc, VecZnxDftAdd, VecZnxDftAddInplace, VecZnxDftAlloc,
VecZnxDftApply, VecZnxDftSub, VecZnxDftSubABInplace, VecZnxDftSubBAInplace, VecZnxIdftApply, VecZnxIdftApplyTmpA,
VecZnxDftApply, VecZnxDftSub, VecZnxDftSubInplace, VecZnxDftSubNegateInplace, VecZnxIdftApply, VecZnxIdftApplyTmpA,
VecZnxIdftApplyTmpBytes,
},
layouts::{Backend, DataViewMut, Module, ScratchOwned, VecZnx, VecZnxBig, VecZnxDft},
@@ -17,7 +17,7 @@ pub fn bench_vec_znx_dft_add<B: Backend>(c: &mut Criterion, label: &str)
where
Module<B>: VecZnxDftAdd<B> + ModuleNew<B> + VecZnxDftAlloc<B>,
{
let group_name: String = format!("vec_znx_dft_add::{}", label);
let group_name: String = format!("vec_znx_dft_add::{label}");
let mut group = c.benchmark_group(group_name);
@@ -62,7 +62,7 @@ pub fn bench_vec_znx_dft_add_inplace<B: Backend>(c: &mut Criterion, label: &str)
where
Module<B>: VecZnxDftAddInplace<B> + ModuleNew<B> + VecZnxDftAlloc<B>,
{
let group_name: String = format!("vec_znx_dft_add_inplace::{}", label);
let group_name: String = format!("vec_znx_dft_add_inplace::{label}");
let mut group = c.benchmark_group(group_name);
@@ -106,7 +106,7 @@ pub fn bench_vec_znx_dft_apply<B: Backend>(c: &mut Criterion, label: &str)
where
Module<B>: VecZnxDftApply<B> + ModuleNew<B> + VecZnxDftAlloc<B>,
{
let group_name: String = format!("vec_znx_dft_apply::{}", label);
let group_name: String = format!("vec_znx_dft_apply::{label}");
let mut group = c.benchmark_group(group_name);
@@ -149,7 +149,7 @@ where
Module<B>: VecZnxIdftApply<B> + ModuleNew<B> + VecZnxIdftApplyTmpBytes + VecZnxDftAlloc<B> + VecZnxBigAlloc<B>,
ScratchOwned<B>: ScratchOwnedAlloc<B> + ScratchOwnedBorrow<B>,
{
let group_name: String = format!("vec_znx_idft_apply::{}", label);
let group_name: String = format!("vec_znx_idft_apply::{label}");
let mut group = c.benchmark_group(group_name);
@@ -194,7 +194,7 @@ pub fn bench_vec_znx_idft_apply_tmpa<B: Backend>(c: &mut Criterion, label: &str)
where
Module<B>: VecZnxIdftApplyTmpA<B> + ModuleNew<B> + VecZnxDftAlloc<B> + VecZnxBigAlloc<B>,
{
let group_name: String = format!("vec_znx_idft_apply_tmpa::{}", label);
let group_name: String = format!("vec_znx_idft_apply_tmpa::{label}");
let mut group = c.benchmark_group(group_name);
@@ -235,7 +235,7 @@ pub fn bench_vec_znx_dft_sub<B: Backend>(c: &mut Criterion, label: &str)
where
Module<B>: VecZnxDftSub<B> + ModuleNew<B> + VecZnxDftAlloc<B>,
{
let group_name: String = format!("vec_znx_dft_sub::{}", label);
let group_name: String = format!("vec_znx_dft_sub::{label}");
let mut group = c.benchmark_group(group_name);
@@ -276,17 +276,17 @@ where
group.finish();
}
pub fn bench_vec_znx_dft_sub_ab_inplace<B: Backend>(c: &mut Criterion, label: &str)
pub fn bench_vec_znx_dft_sub_inplace<B: Backend>(c: &mut Criterion, label: &str)
where
Module<B>: VecZnxDftSubABInplace<B> + ModuleNew<B> + VecZnxDftAlloc<B>,
Module<B>: VecZnxDftSubInplace<B> + ModuleNew<B> + VecZnxDftAlloc<B>,
{
let group_name: String = format!("vec_znx_dft_sub_ab_inplace::{}", label);
let group_name: String = format!("vec_znx_dft_sub_inplace::{label}");
let mut group = c.benchmark_group(group_name);
fn runner<B: Backend>(params: [usize; 3]) -> impl FnMut()
where
Module<B>: VecZnxDftSubABInplace<B> + ModuleNew<B> + VecZnxDftAlloc<B>,
Module<B>: VecZnxDftSubInplace<B> + ModuleNew<B> + VecZnxDftAlloc<B>,
{
let n: usize = params[0];
let cols: usize = params[1];
@@ -305,7 +305,7 @@ where
move || {
for i in 0..cols {
module.vec_znx_dft_sub_ab_inplace(&mut c, i, &a, i);
module.vec_znx_dft_sub_inplace(&mut c, i, &a, i);
}
black_box(());
}
@@ -320,17 +320,17 @@ where
group.finish();
}
pub fn bench_vec_znx_dft_sub_ba_inplace<B: Backend>(c: &mut Criterion, label: &str)
pub fn bench_vec_znx_dft_sub_negate_inplace<B: Backend>(c: &mut Criterion, label: &str)
where
Module<B>: VecZnxDftSubBAInplace<B> + ModuleNew<B> + VecZnxDftAlloc<B>,
Module<B>: VecZnxDftSubNegateInplace<B> + ModuleNew<B> + VecZnxDftAlloc<B>,
{
let group_name: String = format!("vec_znx_dft_sub_ba_inplace::{}", label);
let group_name: String = format!("vec_znx_dft_sub_negate_inplace::{label}");
let mut group = c.benchmark_group(group_name);
fn runner<B: Backend>(params: [usize; 3]) -> impl FnMut()
where
Module<B>: VecZnxDftSubBAInplace<B> + ModuleNew<B> + VecZnxDftAlloc<B>,
Module<B>: VecZnxDftSubNegateInplace<B> + ModuleNew<B> + VecZnxDftAlloc<B>,
{
let n: usize = params[0];
let cols: usize = params[1];
@@ -349,7 +349,7 @@ where
move || {
for i in 0..cols {
module.vec_znx_dft_sub_ba_inplace(&mut c, i, &a, i);
module.vec_znx_dft_sub_negate_inplace(&mut c, i, &a, i);
}
black_box(());
}

View File

@@ -17,7 +17,7 @@ where
Module<B>: ModuleNew<B> + VmpPMatAlloc<B> + VmpPrepare<B> + VmpPrepareTmpBytes,
ScratchOwned<B>: ScratchOwnedAlloc<B> + ScratchOwnedBorrow<B>,
{
let group_name: String = format!("vmp_prepare::{}", label);
let group_name: String = format!("vmp_prepare::{label}");
let mut group = c.benchmark_group(group_name);
@@ -76,7 +76,7 @@ where
Module<B>: ModuleNew<B> + VmpApplyDftTmpBytes + VmpApplyDft<B> + VmpPMatAlloc<B> + VecZnxDftAlloc<B>,
ScratchOwned<B>: ScratchOwnedAlloc<B> + ScratchOwnedBorrow<B>,
{
let group_name: String = format!("vmp_apply_dft::{}", label);
let group_name: String = format!("vmp_apply_dft::{label}");
let mut group = c.benchmark_group(group_name);
@@ -137,7 +137,7 @@ where
Module<B>: ModuleNew<B> + VecZnxDftAlloc<B> + VmpPMatAlloc<B> + VmpApplyDftToDft<B> + VmpApplyDftToDftTmpBytes,
ScratchOwned<B>: ScratchOwnedAlloc<B> + ScratchOwnedBorrow<B>,
{
let group_name: String = format!("vmp_apply_dft_to_dft::{}", label);
let group_name: String = format!("vmp_apply_dft_to_dft::{label}");
let mut group = c.benchmark_group(group_name);
@@ -200,7 +200,7 @@ where
Module<B>: ModuleNew<B> + VecZnxDftAlloc<B> + VmpPMatAlloc<B> + VmpApplyDftToDftAdd<B> + VmpApplyDftToDftAddTmpBytes,
ScratchOwned<B>: ScratchOwnedAlloc<B> + ScratchOwnedBorrow<B>,
{
let group_name: String = format!("vmp_apply_dft_to_dft_add::{}", label);
let group_name: String = format!("vmp_apply_dft_to_dft_add::{label}");
let mut group = c.benchmark_group(group_name);