use std::hint::black_box; use criterion::{BenchmarkId, Criterion}; use crate::{ api::{ModuleNew, VecZnxSub, VecZnxSubABInplace, VecZnxSubBAInplace}, layouts::{Backend, FillUniform, Module, VecZnx, VecZnxToMut, VecZnxToRef, ZnxInfos, ZnxView, ZnxViewMut}, oep::{ModuleNewImpl, VecZnxSubABInplaceImpl, VecZnxSubBAInplaceImpl, VecZnxSubImpl}, reference::znx::{ZnxCopy, ZnxNegate, ZnxNegateInplace, ZnxSub, ZnxSubABInplace, ZnxSubBAInplace, ZnxZero}, source::Source, }; pub fn vec_znx_sub(res: &mut R, res_col: usize, a: &A, a_col: usize, b: &B, b_col: usize) where R: VecZnxToMut, A: VecZnxToRef, B: VecZnxToRef, ZNXARI: ZnxSub + ZnxNegate + ZnxZero + ZnxCopy, { let a: VecZnx<&[u8]> = a.to_ref(); let b: VecZnx<&[u8]> = b.to_ref(); let mut res: VecZnx<&mut [u8]> = res.to_mut(); #[cfg(debug_assertions)] { assert_eq!(a.n(), res.n()); assert_eq!(b.n(), res.n()); } let res_size: usize = res.size(); let a_size: usize = a.size(); let b_size: usize = b.size(); if a_size <= b_size { let sum_size: usize = a_size.min(res_size); let cpy_size: usize = b_size.min(res_size); for j in 0..sum_size { ZNXARI::znx_sub(res.at_mut(res_col, j), a.at(a_col, j), b.at(b_col, j)); } for j in sum_size..cpy_size { ZNXARI::znx_negate(res.at_mut(res_col, j), b.at(b_col, j)); } for j in cpy_size..res_size { ZNXARI::znx_zero(res.at_mut(res_col, j)); } } else { let sum_size: usize = b_size.min(res_size); let cpy_size: usize = a_size.min(res_size); for j in 0..sum_size { ZNXARI::znx_sub(res.at_mut(res_col, j), a.at(a_col, j), b.at(b_col, j)); } for j in sum_size..cpy_size { ZNXARI::znx_copy(res.at_mut(res_col, j), a.at(a_col, j)); } for j in cpy_size..res_size { ZNXARI::znx_zero(res.at_mut(res_col, j)); } } } pub fn vec_znx_sub_ab_inplace(res: &mut R, res_col: usize, a: &A, a_col: usize) where R: VecZnxToMut, A: VecZnxToRef, ZNXARI: ZnxSubABInplace, { let a: VecZnx<&[u8]> = a.to_ref(); let mut res: VecZnx<&mut [u8]> = res.to_mut(); #[cfg(debug_assertions)] { assert_eq!(a.n(), res.n()); } let res_size: usize = res.size(); let a_size: usize = a.size(); let sum_size: usize = a_size.min(res_size); for j in 0..sum_size { ZNXARI::znx_sub_ab_inplace(res.at_mut(res_col, j), a.at(a_col, j)); } } pub fn vec_znx_sub_ba_inplace(res: &mut R, res_col: usize, a: &A, a_col: usize) where R: VecZnxToMut, A: VecZnxToRef, ZNXARI: ZnxSubBAInplace + ZnxNegateInplace, { let a: VecZnx<&[u8]> = a.to_ref(); let mut res: VecZnx<&mut [u8]> = res.to_mut(); #[cfg(debug_assertions)] { assert_eq!(a.n(), res.n()); } let res_size: usize = res.size(); let a_size: usize = a.size(); let sum_size: usize = a_size.min(res_size); for j in 0..sum_size { ZNXARI::znx_sub_ba_inplace(res.at_mut(res_col, j), a.at(a_col, j)); } for j in sum_size..res_size { ZNXARI::znx_negate_inplace(res.at_mut(res_col, j)); } } pub fn bench_vec_znx_sub(c: &mut Criterion, label: &str) where B: Backend + ModuleNewImpl + VecZnxSubImpl, { let group_name: String = format!("vec_znx_sub::{}", label); let mut group = c.benchmark_group(group_name); fn runner(params: [usize; 3]) -> impl FnMut() where Module: VecZnxSub + ModuleNew, { let n: usize = 1 << params[0]; let cols: usize = params[1]; let size: usize = params[2]; let module: Module = Module::::new(n as u64); let mut source: Source = Source::new([0u8; 32]); let mut a: VecZnx> = VecZnx::alloc(n, cols, size); let mut b: VecZnx> = VecZnx::alloc(n, cols, size); let mut c: VecZnx> = VecZnx::alloc(n, cols, size); // Fill a with random i64 a.fill_uniform(50, &mut source); b.fill_uniform(50, &mut source); move || { for i in 0..cols { module.vec_znx_sub(&mut c, i, &a, i, &b, i); } black_box(()); } } for params in [[10, 2, 2], [11, 2, 4], [12, 2, 8], [13, 2, 16], [14, 2, 32]] { let id: BenchmarkId = BenchmarkId::from_parameter(format!("{}x({}x{})", 1 << params[0], params[1], params[2],)); let mut runner = runner::(params); group.bench_with_input(id, &(), |b, _| b.iter(&mut runner)); } group.finish(); } pub fn bench_vec_znx_sub_ab_inplace(c: &mut Criterion, label: &str) where B: Backend + ModuleNewImpl + VecZnxSubABInplaceImpl, { let group_name: String = format!("vec_znx_sub_ab_inplace::{}", label); let mut group = c.benchmark_group(group_name); fn runner(params: [usize; 3]) -> impl FnMut() where Module: VecZnxSubABInplace + ModuleNew, { let n: usize = 1 << params[0]; let cols: usize = params[1]; let size: usize = params[2]; let module: Module = Module::::new(n as u64); let mut source: Source = Source::new([0u8; 32]); let mut a: VecZnx> = VecZnx::alloc(n, cols, size); let mut b: VecZnx> = VecZnx::alloc(n, cols, size); // Fill a with random i64 a.fill_uniform(50, &mut source); b.fill_uniform(50, &mut source); move || { for i in 0..cols { module.vec_znx_sub_ab_inplace(&mut b, i, &a, i); } black_box(()); } } for params in [[10, 2, 2], [11, 2, 4], [12, 2, 8], [13, 2, 16], [14, 2, 32]] { let id: BenchmarkId = BenchmarkId::from_parameter(format!("{}x({}x{})", 1 << params[0], params[1], params[2])); let mut runner = runner::(params); group.bench_with_input(id, &(), |b, _| b.iter(&mut runner)); } group.finish(); } pub fn bench_vec_znx_sub_ba_inplace(c: &mut Criterion, label: &str) where B: Backend + ModuleNewImpl + VecZnxSubBAInplaceImpl, { let group_name: String = format!("vec_znx_sub_ba_inplace::{}", label); let mut group = c.benchmark_group(group_name); fn runner(params: [usize; 3]) -> impl FnMut() where Module: VecZnxSubBAInplace + ModuleNew, { let n: usize = 1 << params[0]; let cols: usize = params[1]; let size: usize = params[2]; let module: Module = Module::::new(n as u64); let mut source: Source = Source::new([0u8; 32]); let mut a: VecZnx> = VecZnx::alloc(n, cols, size); let mut b: VecZnx> = VecZnx::alloc(n, cols, size); // Fill a with random i64 a.fill_uniform(50, &mut source); b.fill_uniform(50, &mut source); move || { for i in 0..cols { module.vec_znx_sub_ba_inplace(&mut b, i, &a, i); } black_box(()); } } for params in [[10, 2, 2], [11, 2, 4], [12, 2, 8], [13, 2, 16], [14, 2, 32]] { let id: BenchmarkId = BenchmarkId::from_parameter(format!("{}x({}x{})", 1 << params[0], params[1], params[2])); let mut runner = runner::(params); group.bench_with_input(id, &(), |b, _| b.iter(&mut runner)); } group.finish(); }