mirror of
https://github.com/arnaucube/poulpy.git
synced 2026-02-10 13:16:44 +01:00
fixed gadget product bench
This commit is contained in:
@@ -1,12 +1,13 @@
|
||||
/*
|
||||
use base2k::{FFT64, Module, SvpPPolOps, VecZnx, VmpPMat, alloc_aligned_u8};
|
||||
use base2k::{
|
||||
FFT64, Infos, Module, Sampling, SvpPPolOps, VecZnx, VecZnxDft, VecZnxDftOps, VecZnxOps,
|
||||
VmpPMat, alloc_aligned_u8,
|
||||
};
|
||||
use criterion::{BenchmarkId, Criterion, criterion_group, criterion_main};
|
||||
use rlwe::{
|
||||
ciphertext::{Ciphertext, new_gadget_ciphertext},
|
||||
elem::{Elem, ElemCommon},
|
||||
elem::ElemCommon,
|
||||
encryptor::{encrypt_grlwe_sk, encrypt_grlwe_sk_tmp_bytes},
|
||||
gadget_product::{gadget_product_core, gadget_product_tmp_bytes},
|
||||
key_generator::gen_switching_key_thread_safe_tmp_bytes,
|
||||
keys::SecretKey,
|
||||
parameters::{Parameters, ParametersLiteral},
|
||||
};
|
||||
@@ -15,12 +16,18 @@ use sampling::source::Source;
|
||||
fn bench_gadget_product_inplace(c: &mut Criterion) {
|
||||
fn runner<'a>(
|
||||
module: &'a Module,
|
||||
elem: &'a mut Elem<VecZnx>,
|
||||
gadget_ct: &'a Ciphertext<VmpPMat>,
|
||||
res_dft_0: &'a mut VecZnxDft,
|
||||
res_dft_1: &'a mut VecZnxDft,
|
||||
a: &'a VecZnx,
|
||||
a_cols: usize,
|
||||
b: &'a Ciphertext<VmpPMat>,
|
||||
b_cols: usize,
|
||||
tmp_bytes: &'a mut [u8],
|
||||
) -> Box<dyn FnMut() + 'a> {
|
||||
Box::new(move || {
|
||||
gadget_product_inplace::<true, _>(module, elem, gadget_ct, elem.cols() + 1, tmp_bytes)
|
||||
gadget_product_core(
|
||||
module, res_dft_0, res_dft_1, a, a_cols, b, b_cols, tmp_bytes,
|
||||
);
|
||||
})
|
||||
}
|
||||
|
||||
@@ -41,14 +48,7 @@ fn bench_gadget_product_inplace(c: &mut Criterion) {
|
||||
let params: Parameters = Parameters::new::<FFT64>(¶ms_lit);
|
||||
|
||||
let mut tmp_bytes: Vec<u8> = alloc_aligned_u8(
|
||||
params.decrypt_rlwe_thread_safe_tmp_byte(params.log_q())
|
||||
| params.encrypt_rlwe_sk_tmp_bytes(params.log_q())
|
||||
| gen_switching_key_thread_safe_tmp_bytes(
|
||||
params.module(),
|
||||
params.log_base2k(),
|
||||
params.cols_q(),
|
||||
params.log_q(),
|
||||
)
|
||||
params.encrypt_rlwe_sk_tmp_bytes(params.log_q())
|
||||
| gadget_product_tmp_bytes(
|
||||
params.module(),
|
||||
params.log_base2k(),
|
||||
@@ -89,7 +89,7 @@ fn bench_gadget_product_inplace(c: &mut Criterion) {
|
||||
params.log_qp(),
|
||||
);
|
||||
|
||||
encrypt_grlwe_sk_thread_safe(
|
||||
encrypt_grlwe_sk(
|
||||
params.module(),
|
||||
&mut gadget_ct,
|
||||
&sk0.0,
|
||||
@@ -111,8 +111,28 @@ fn bench_gadget_product_inplace(c: &mut Criterion) {
|
||||
&mut tmp_bytes,
|
||||
);
|
||||
|
||||
let mut res_dft_0: VecZnxDft = params.module().new_vec_znx_dft(gadget_ct.cols());
|
||||
let mut res_dft_1: VecZnxDft = params.module().new_vec_znx_dft(gadget_ct.cols());
|
||||
|
||||
let mut a: VecZnx = params.module().new_vec_znx(params.cols_q());
|
||||
params
|
||||
.module()
|
||||
.fill_uniform(params.log_base2k(), &mut a, params.cols_q(), &mut source_xa);
|
||||
|
||||
let a_cols: usize = a.cols();
|
||||
let b_cols: usize = gadget_ct.cols();
|
||||
|
||||
let runners: [(String, Box<dyn FnMut()>); 1] = [(format!("gadget_product"), {
|
||||
runner(params.module(), &mut ct.0, &gadget_ct, &mut tmp_bytes)
|
||||
runner(
|
||||
params.module(),
|
||||
&mut res_dft_0,
|
||||
&mut res_dft_1,
|
||||
&mut a,
|
||||
a_cols,
|
||||
&gadget_ct,
|
||||
b_cols,
|
||||
&mut tmp_bytes,
|
||||
)
|
||||
})];
|
||||
|
||||
for (name, mut runner) in runners {
|
||||
@@ -126,4 +146,3 @@ fn bench_gadget_product_inplace(c: &mut Criterion) {
|
||||
|
||||
criterion_group!(benches, bench_gadget_product_inplace);
|
||||
criterion_main!(benches);
|
||||
*/
|
||||
|
||||
@@ -6,7 +6,7 @@ use crate::plaintext::Plaintext;
|
||||
use base2k::sampling::Sampling;
|
||||
use base2k::{
|
||||
Module, Scalar, SvpPPol, SvpPPolOps, VecZnx, VecZnxApi, VecZnxBig, VecZnxBigOps, VecZnxBorrow,
|
||||
VecZnxDft, VecZnxDftOps, VecZnxOps, VmpPMat, VmpPMatOps, cast_mut,
|
||||
VecZnxDft, VecZnxDftOps, VecZnxOps, VmpPMat, VmpPMatOps,
|
||||
};
|
||||
|
||||
use sampling::source::{Source, new_seed};
|
||||
|
||||
@@ -53,7 +53,7 @@ impl Parameters {
|
||||
///
|
||||
/// res = sum[min(a_ncols, b_nrows)] decomp(a, i) * (-B[i]s + m * 2^{-k*i} + E[i], B[i])
|
||||
/// = (cs + m * a + e, c) with min(res_cols, b_cols) cols.
|
||||
pub fn gadget_product_core<const OVERWRITE: bool, T>(
|
||||
pub fn gadget_product_core<T>(
|
||||
module: &Module,
|
||||
res_dft_0: &mut VecZnxDft,
|
||||
res_dft_1: &mut VecZnxDft,
|
||||
@@ -108,7 +108,6 @@ mod test {
|
||||
VecZnxDftOps, VecZnxOps, VmpPMat,
|
||||
};
|
||||
use sampling::source::{Source, new_seed};
|
||||
use std::cmp::min;
|
||||
|
||||
#[test]
|
||||
fn test_gadget_product_core() {
|
||||
@@ -232,7 +231,7 @@ mod test {
|
||||
|
||||
// res_dft_0 = DFT(gct_[0] * ct[1] = a * (-bs' + s + e) = -cs' + as + e')
|
||||
// res_dft_1 = DFT(gct_[1] * ct[1] = a * b = c)
|
||||
gadget_product_core::<true, _>(
|
||||
gadget_product_core(
|
||||
params.module(),
|
||||
&mut res_dft_0,
|
||||
&mut res_dft_1,
|
||||
|
||||
Reference in New Issue
Block a user