From 5d3dfe0f3cedf1df79294e3d98c02acac5b54066 Mon Sep 17 00:00:00 2001 From: Jean-Philippe Bossuat Date: Mon, 24 Feb 2025 10:26:12 +0100 Subject: [PATCH] fixed gadget product bench --- rlwe/benches/gadget_product.rs | 55 +++++++++++++++++++++++----------- rlwe/src/encryptor.rs | 2 +- rlwe/src/gadget_product.rs | 5 ++-- 3 files changed, 40 insertions(+), 22 deletions(-) diff --git a/rlwe/benches/gadget_product.rs b/rlwe/benches/gadget_product.rs index 5cedf5c..10ad9f9 100644 --- a/rlwe/benches/gadget_product.rs +++ b/rlwe/benches/gadget_product.rs @@ -1,12 +1,13 @@ -/* -use base2k::{FFT64, Module, SvpPPolOps, VecZnx, VmpPMat, alloc_aligned_u8}; +use base2k::{ + FFT64, Infos, Module, Sampling, SvpPPolOps, VecZnx, VecZnxDft, VecZnxDftOps, VecZnxOps, + VmpPMat, alloc_aligned_u8, +}; use criterion::{BenchmarkId, Criterion, criterion_group, criterion_main}; use rlwe::{ ciphertext::{Ciphertext, new_gadget_ciphertext}, - elem::{Elem, ElemCommon}, + elem::ElemCommon, encryptor::{encrypt_grlwe_sk, encrypt_grlwe_sk_tmp_bytes}, gadget_product::{gadget_product_core, gadget_product_tmp_bytes}, - key_generator::gen_switching_key_thread_safe_tmp_bytes, keys::SecretKey, parameters::{Parameters, ParametersLiteral}, }; @@ -15,12 +16,18 @@ use sampling::source::Source; fn bench_gadget_product_inplace(c: &mut Criterion) { fn runner<'a>( module: &'a Module, - elem: &'a mut Elem, - gadget_ct: &'a Ciphertext, + res_dft_0: &'a mut VecZnxDft, + res_dft_1: &'a mut VecZnxDft, + a: &'a VecZnx, + a_cols: usize, + b: &'a Ciphertext, + b_cols: usize, tmp_bytes: &'a mut [u8], ) -> Box { Box::new(move || { - gadget_product_inplace::(module, elem, gadget_ct, elem.cols() + 1, tmp_bytes) + gadget_product_core( + module, res_dft_0, res_dft_1, a, a_cols, b, b_cols, tmp_bytes, + ); }) } @@ -41,14 +48,7 @@ fn bench_gadget_product_inplace(c: &mut Criterion) { let params: Parameters = Parameters::new::(¶ms_lit); let mut tmp_bytes: Vec = alloc_aligned_u8( - params.decrypt_rlwe_thread_safe_tmp_byte(params.log_q()) - | params.encrypt_rlwe_sk_tmp_bytes(params.log_q()) - | gen_switching_key_thread_safe_tmp_bytes( - params.module(), - params.log_base2k(), - params.cols_q(), - params.log_q(), - ) + params.encrypt_rlwe_sk_tmp_bytes(params.log_q()) | gadget_product_tmp_bytes( params.module(), params.log_base2k(), @@ -89,7 +89,7 @@ fn bench_gadget_product_inplace(c: &mut Criterion) { params.log_qp(), ); - encrypt_grlwe_sk_thread_safe( + encrypt_grlwe_sk( params.module(), &mut gadget_ct, &sk0.0, @@ -111,8 +111,28 @@ fn bench_gadget_product_inplace(c: &mut Criterion) { &mut tmp_bytes, ); + let mut res_dft_0: VecZnxDft = params.module().new_vec_znx_dft(gadget_ct.cols()); + let mut res_dft_1: VecZnxDft = params.module().new_vec_znx_dft(gadget_ct.cols()); + + let mut a: VecZnx = params.module().new_vec_znx(params.cols_q()); + params + .module() + .fill_uniform(params.log_base2k(), &mut a, params.cols_q(), &mut source_xa); + + let a_cols: usize = a.cols(); + let b_cols: usize = gadget_ct.cols(); + let runners: [(String, Box); 1] = [(format!("gadget_product"), { - runner(params.module(), &mut ct.0, &gadget_ct, &mut tmp_bytes) + runner( + params.module(), + &mut res_dft_0, + &mut res_dft_1, + &mut a, + a_cols, + &gadget_ct, + b_cols, + &mut tmp_bytes, + ) })]; for (name, mut runner) in runners { @@ -126,4 +146,3 @@ fn bench_gadget_product_inplace(c: &mut Criterion) { criterion_group!(benches, bench_gadget_product_inplace); criterion_main!(benches); -*/ diff --git a/rlwe/src/encryptor.rs b/rlwe/src/encryptor.rs index 7156d52..cb69944 100644 --- a/rlwe/src/encryptor.rs +++ b/rlwe/src/encryptor.rs @@ -6,7 +6,7 @@ use crate::plaintext::Plaintext; use base2k::sampling::Sampling; use base2k::{ Module, Scalar, SvpPPol, SvpPPolOps, VecZnx, VecZnxApi, VecZnxBig, VecZnxBigOps, VecZnxBorrow, - VecZnxDft, VecZnxDftOps, VecZnxOps, VmpPMat, VmpPMatOps, cast_mut, + VecZnxDft, VecZnxDftOps, VecZnxOps, VmpPMat, VmpPMatOps, }; use sampling::source::{Source, new_seed}; diff --git a/rlwe/src/gadget_product.rs b/rlwe/src/gadget_product.rs index b1551fd..f977983 100644 --- a/rlwe/src/gadget_product.rs +++ b/rlwe/src/gadget_product.rs @@ -53,7 +53,7 @@ impl Parameters { /// /// res = sum[min(a_ncols, b_nrows)] decomp(a, i) * (-B[i]s + m * 2^{-k*i} + E[i], B[i]) /// = (cs + m * a + e, c) with min(res_cols, b_cols) cols. -pub fn gadget_product_core( +pub fn gadget_product_core( module: &Module, res_dft_0: &mut VecZnxDft, res_dft_1: &mut VecZnxDft, @@ -108,7 +108,6 @@ mod test { VecZnxDftOps, VecZnxOps, VmpPMat, }; use sampling::source::{Source, new_seed}; - use std::cmp::min; #[test] fn test_gadget_product_core() { @@ -232,7 +231,7 @@ mod test { // res_dft_0 = DFT(gct_[0] * ct[1] = a * (-bs' + s + e) = -cs' + as + e') // res_dft_1 = DFT(gct_[1] * ct[1] = a * b = c) - gadget_product_core::( + gadget_product_core( params.module(), &mut res_dft_0, &mut res_dft_1,