fixed gadget product bench

This commit is contained in:
Jean-Philippe Bossuat
2025-02-24 10:26:12 +01:00
parent 3634ab7746
commit 5d3dfe0f3c
3 changed files with 40 additions and 22 deletions

View File

@@ -1,12 +1,13 @@
/* use base2k::{
use base2k::{FFT64, Module, SvpPPolOps, VecZnx, VmpPMat, alloc_aligned_u8}; FFT64, Infos, Module, Sampling, SvpPPolOps, VecZnx, VecZnxDft, VecZnxDftOps, VecZnxOps,
VmpPMat, alloc_aligned_u8,
};
use criterion::{BenchmarkId, Criterion, criterion_group, criterion_main}; use criterion::{BenchmarkId, Criterion, criterion_group, criterion_main};
use rlwe::{ use rlwe::{
ciphertext::{Ciphertext, new_gadget_ciphertext}, ciphertext::{Ciphertext, new_gadget_ciphertext},
elem::{Elem, ElemCommon}, elem::ElemCommon,
encryptor::{encrypt_grlwe_sk, encrypt_grlwe_sk_tmp_bytes}, encryptor::{encrypt_grlwe_sk, encrypt_grlwe_sk_tmp_bytes},
gadget_product::{gadget_product_core, gadget_product_tmp_bytes}, gadget_product::{gadget_product_core, gadget_product_tmp_bytes},
key_generator::gen_switching_key_thread_safe_tmp_bytes,
keys::SecretKey, keys::SecretKey,
parameters::{Parameters, ParametersLiteral}, parameters::{Parameters, ParametersLiteral},
}; };
@@ -15,12 +16,18 @@ use sampling::source::Source;
fn bench_gadget_product_inplace(c: &mut Criterion) { fn bench_gadget_product_inplace(c: &mut Criterion) {
fn runner<'a>( fn runner<'a>(
module: &'a Module, module: &'a Module,
elem: &'a mut Elem<VecZnx>, res_dft_0: &'a mut VecZnxDft,
gadget_ct: &'a Ciphertext<VmpPMat>, res_dft_1: &'a mut VecZnxDft,
a: &'a VecZnx,
a_cols: usize,
b: &'a Ciphertext<VmpPMat>,
b_cols: usize,
tmp_bytes: &'a mut [u8], tmp_bytes: &'a mut [u8],
) -> Box<dyn FnMut() + 'a> { ) -> Box<dyn FnMut() + 'a> {
Box::new(move || { Box::new(move || {
gadget_product_inplace::<true, _>(module, elem, gadget_ct, elem.cols() + 1, tmp_bytes) gadget_product_core(
module, res_dft_0, res_dft_1, a, a_cols, b, b_cols, tmp_bytes,
);
}) })
} }
@@ -41,14 +48,7 @@ fn bench_gadget_product_inplace(c: &mut Criterion) {
let params: Parameters = Parameters::new::<FFT64>(&params_lit); let params: Parameters = Parameters::new::<FFT64>(&params_lit);
let mut tmp_bytes: Vec<u8> = alloc_aligned_u8( let mut tmp_bytes: Vec<u8> = alloc_aligned_u8(
params.decrypt_rlwe_thread_safe_tmp_byte(params.log_q()) params.encrypt_rlwe_sk_tmp_bytes(params.log_q())
| params.encrypt_rlwe_sk_tmp_bytes(params.log_q())
| gen_switching_key_thread_safe_tmp_bytes(
params.module(),
params.log_base2k(),
params.cols_q(),
params.log_q(),
)
| gadget_product_tmp_bytes( | gadget_product_tmp_bytes(
params.module(), params.module(),
params.log_base2k(), params.log_base2k(),
@@ -89,7 +89,7 @@ fn bench_gadget_product_inplace(c: &mut Criterion) {
params.log_qp(), params.log_qp(),
); );
encrypt_grlwe_sk_thread_safe( encrypt_grlwe_sk(
params.module(), params.module(),
&mut gadget_ct, &mut gadget_ct,
&sk0.0, &sk0.0,
@@ -111,8 +111,28 @@ fn bench_gadget_product_inplace(c: &mut Criterion) {
&mut tmp_bytes, &mut tmp_bytes,
); );
let mut res_dft_0: VecZnxDft = params.module().new_vec_znx_dft(gadget_ct.cols());
let mut res_dft_1: VecZnxDft = params.module().new_vec_znx_dft(gadget_ct.cols());
let mut a: VecZnx = params.module().new_vec_znx(params.cols_q());
params
.module()
.fill_uniform(params.log_base2k(), &mut a, params.cols_q(), &mut source_xa);
let a_cols: usize = a.cols();
let b_cols: usize = gadget_ct.cols();
let runners: [(String, Box<dyn FnMut()>); 1] = [(format!("gadget_product"), { let runners: [(String, Box<dyn FnMut()>); 1] = [(format!("gadget_product"), {
runner(params.module(), &mut ct.0, &gadget_ct, &mut tmp_bytes) runner(
params.module(),
&mut res_dft_0,
&mut res_dft_1,
&mut a,
a_cols,
&gadget_ct,
b_cols,
&mut tmp_bytes,
)
})]; })];
for (name, mut runner) in runners { for (name, mut runner) in runners {
@@ -126,4 +146,3 @@ fn bench_gadget_product_inplace(c: &mut Criterion) {
criterion_group!(benches, bench_gadget_product_inplace); criterion_group!(benches, bench_gadget_product_inplace);
criterion_main!(benches); criterion_main!(benches);
*/

View File

@@ -6,7 +6,7 @@ use crate::plaintext::Plaintext;
use base2k::sampling::Sampling; use base2k::sampling::Sampling;
use base2k::{ use base2k::{
Module, Scalar, SvpPPol, SvpPPolOps, VecZnx, VecZnxApi, VecZnxBig, VecZnxBigOps, VecZnxBorrow, Module, Scalar, SvpPPol, SvpPPolOps, VecZnx, VecZnxApi, VecZnxBig, VecZnxBigOps, VecZnxBorrow,
VecZnxDft, VecZnxDftOps, VecZnxOps, VmpPMat, VmpPMatOps, cast_mut, VecZnxDft, VecZnxDftOps, VecZnxOps, VmpPMat, VmpPMatOps,
}; };
use sampling::source::{Source, new_seed}; use sampling::source::{Source, new_seed};

View File

@@ -53,7 +53,7 @@ impl Parameters {
/// ///
/// res = sum[min(a_ncols, b_nrows)] decomp(a, i) * (-B[i]s + m * 2^{-k*i} + E[i], B[i]) /// res = sum[min(a_ncols, b_nrows)] decomp(a, i) * (-B[i]s + m * 2^{-k*i} + E[i], B[i])
/// = (cs + m * a + e, c) with min(res_cols, b_cols) cols. /// = (cs + m * a + e, c) with min(res_cols, b_cols) cols.
pub fn gadget_product_core<const OVERWRITE: bool, T>( pub fn gadget_product_core<T>(
module: &Module, module: &Module,
res_dft_0: &mut VecZnxDft, res_dft_0: &mut VecZnxDft,
res_dft_1: &mut VecZnxDft, res_dft_1: &mut VecZnxDft,
@@ -108,7 +108,6 @@ mod test {
VecZnxDftOps, VecZnxOps, VmpPMat, VecZnxDftOps, VecZnxOps, VmpPMat,
}; };
use sampling::source::{Source, new_seed}; use sampling::source::{Source, new_seed};
use std::cmp::min;
#[test] #[test]
fn test_gadget_product_core() { fn test_gadget_product_core() {
@@ -232,7 +231,7 @@ mod test {
// res_dft_0 = DFT(gct_[0] * ct[1] = a * (-bs' + s + e) = -cs' + as + e') // res_dft_0 = DFT(gct_[0] * ct[1] = a * (-bs' + s + e) = -cs' + as + e')
// res_dft_1 = DFT(gct_[1] * ct[1] = a * b = c) // res_dft_1 = DFT(gct_[1] * ct[1] = a * b = c)
gadget_product_core::<true, _>( gadget_product_core(
params.module(), params.module(),
&mut res_dft_0, &mut res_dft_0,
&mut res_dft_1, &mut res_dft_1,