mirror of
https://github.com/arnaucube/poulpy.git
synced 2026-02-10 05:06:44 +01:00
fixed gadget product bench
This commit is contained in:
@@ -1,12 +1,13 @@
|
|||||||
/*
|
use base2k::{
|
||||||
use base2k::{FFT64, Module, SvpPPolOps, VecZnx, VmpPMat, alloc_aligned_u8};
|
FFT64, Infos, Module, Sampling, SvpPPolOps, VecZnx, VecZnxDft, VecZnxDftOps, VecZnxOps,
|
||||||
|
VmpPMat, alloc_aligned_u8,
|
||||||
|
};
|
||||||
use criterion::{BenchmarkId, Criterion, criterion_group, criterion_main};
|
use criterion::{BenchmarkId, Criterion, criterion_group, criterion_main};
|
||||||
use rlwe::{
|
use rlwe::{
|
||||||
ciphertext::{Ciphertext, new_gadget_ciphertext},
|
ciphertext::{Ciphertext, new_gadget_ciphertext},
|
||||||
elem::{Elem, ElemCommon},
|
elem::ElemCommon,
|
||||||
encryptor::{encrypt_grlwe_sk, encrypt_grlwe_sk_tmp_bytes},
|
encryptor::{encrypt_grlwe_sk, encrypt_grlwe_sk_tmp_bytes},
|
||||||
gadget_product::{gadget_product_core, gadget_product_tmp_bytes},
|
gadget_product::{gadget_product_core, gadget_product_tmp_bytes},
|
||||||
key_generator::gen_switching_key_thread_safe_tmp_bytes,
|
|
||||||
keys::SecretKey,
|
keys::SecretKey,
|
||||||
parameters::{Parameters, ParametersLiteral},
|
parameters::{Parameters, ParametersLiteral},
|
||||||
};
|
};
|
||||||
@@ -15,12 +16,18 @@ use sampling::source::Source;
|
|||||||
fn bench_gadget_product_inplace(c: &mut Criterion) {
|
fn bench_gadget_product_inplace(c: &mut Criterion) {
|
||||||
fn runner<'a>(
|
fn runner<'a>(
|
||||||
module: &'a Module,
|
module: &'a Module,
|
||||||
elem: &'a mut Elem<VecZnx>,
|
res_dft_0: &'a mut VecZnxDft,
|
||||||
gadget_ct: &'a Ciphertext<VmpPMat>,
|
res_dft_1: &'a mut VecZnxDft,
|
||||||
|
a: &'a VecZnx,
|
||||||
|
a_cols: usize,
|
||||||
|
b: &'a Ciphertext<VmpPMat>,
|
||||||
|
b_cols: usize,
|
||||||
tmp_bytes: &'a mut [u8],
|
tmp_bytes: &'a mut [u8],
|
||||||
) -> Box<dyn FnMut() + 'a> {
|
) -> Box<dyn FnMut() + 'a> {
|
||||||
Box::new(move || {
|
Box::new(move || {
|
||||||
gadget_product_inplace::<true, _>(module, elem, gadget_ct, elem.cols() + 1, tmp_bytes)
|
gadget_product_core(
|
||||||
|
module, res_dft_0, res_dft_1, a, a_cols, b, b_cols, tmp_bytes,
|
||||||
|
);
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -41,14 +48,7 @@ fn bench_gadget_product_inplace(c: &mut Criterion) {
|
|||||||
let params: Parameters = Parameters::new::<FFT64>(¶ms_lit);
|
let params: Parameters = Parameters::new::<FFT64>(¶ms_lit);
|
||||||
|
|
||||||
let mut tmp_bytes: Vec<u8> = alloc_aligned_u8(
|
let mut tmp_bytes: Vec<u8> = alloc_aligned_u8(
|
||||||
params.decrypt_rlwe_thread_safe_tmp_byte(params.log_q())
|
params.encrypt_rlwe_sk_tmp_bytes(params.log_q())
|
||||||
| params.encrypt_rlwe_sk_tmp_bytes(params.log_q())
|
|
||||||
| gen_switching_key_thread_safe_tmp_bytes(
|
|
||||||
params.module(),
|
|
||||||
params.log_base2k(),
|
|
||||||
params.cols_q(),
|
|
||||||
params.log_q(),
|
|
||||||
)
|
|
||||||
| gadget_product_tmp_bytes(
|
| gadget_product_tmp_bytes(
|
||||||
params.module(),
|
params.module(),
|
||||||
params.log_base2k(),
|
params.log_base2k(),
|
||||||
@@ -89,7 +89,7 @@ fn bench_gadget_product_inplace(c: &mut Criterion) {
|
|||||||
params.log_qp(),
|
params.log_qp(),
|
||||||
);
|
);
|
||||||
|
|
||||||
encrypt_grlwe_sk_thread_safe(
|
encrypt_grlwe_sk(
|
||||||
params.module(),
|
params.module(),
|
||||||
&mut gadget_ct,
|
&mut gadget_ct,
|
||||||
&sk0.0,
|
&sk0.0,
|
||||||
@@ -111,8 +111,28 @@ fn bench_gadget_product_inplace(c: &mut Criterion) {
|
|||||||
&mut tmp_bytes,
|
&mut tmp_bytes,
|
||||||
);
|
);
|
||||||
|
|
||||||
|
let mut res_dft_0: VecZnxDft = params.module().new_vec_znx_dft(gadget_ct.cols());
|
||||||
|
let mut res_dft_1: VecZnxDft = params.module().new_vec_znx_dft(gadget_ct.cols());
|
||||||
|
|
||||||
|
let mut a: VecZnx = params.module().new_vec_znx(params.cols_q());
|
||||||
|
params
|
||||||
|
.module()
|
||||||
|
.fill_uniform(params.log_base2k(), &mut a, params.cols_q(), &mut source_xa);
|
||||||
|
|
||||||
|
let a_cols: usize = a.cols();
|
||||||
|
let b_cols: usize = gadget_ct.cols();
|
||||||
|
|
||||||
let runners: [(String, Box<dyn FnMut()>); 1] = [(format!("gadget_product"), {
|
let runners: [(String, Box<dyn FnMut()>); 1] = [(format!("gadget_product"), {
|
||||||
runner(params.module(), &mut ct.0, &gadget_ct, &mut tmp_bytes)
|
runner(
|
||||||
|
params.module(),
|
||||||
|
&mut res_dft_0,
|
||||||
|
&mut res_dft_1,
|
||||||
|
&mut a,
|
||||||
|
a_cols,
|
||||||
|
&gadget_ct,
|
||||||
|
b_cols,
|
||||||
|
&mut tmp_bytes,
|
||||||
|
)
|
||||||
})];
|
})];
|
||||||
|
|
||||||
for (name, mut runner) in runners {
|
for (name, mut runner) in runners {
|
||||||
@@ -126,4 +146,3 @@ fn bench_gadget_product_inplace(c: &mut Criterion) {
|
|||||||
|
|
||||||
criterion_group!(benches, bench_gadget_product_inplace);
|
criterion_group!(benches, bench_gadget_product_inplace);
|
||||||
criterion_main!(benches);
|
criterion_main!(benches);
|
||||||
*/
|
|
||||||
|
|||||||
@@ -6,7 +6,7 @@ use crate::plaintext::Plaintext;
|
|||||||
use base2k::sampling::Sampling;
|
use base2k::sampling::Sampling;
|
||||||
use base2k::{
|
use base2k::{
|
||||||
Module, Scalar, SvpPPol, SvpPPolOps, VecZnx, VecZnxApi, VecZnxBig, VecZnxBigOps, VecZnxBorrow,
|
Module, Scalar, SvpPPol, SvpPPolOps, VecZnx, VecZnxApi, VecZnxBig, VecZnxBigOps, VecZnxBorrow,
|
||||||
VecZnxDft, VecZnxDftOps, VecZnxOps, VmpPMat, VmpPMatOps, cast_mut,
|
VecZnxDft, VecZnxDftOps, VecZnxOps, VmpPMat, VmpPMatOps,
|
||||||
};
|
};
|
||||||
|
|
||||||
use sampling::source::{Source, new_seed};
|
use sampling::source::{Source, new_seed};
|
||||||
|
|||||||
@@ -53,7 +53,7 @@ impl Parameters {
|
|||||||
///
|
///
|
||||||
/// res = sum[min(a_ncols, b_nrows)] decomp(a, i) * (-B[i]s + m * 2^{-k*i} + E[i], B[i])
|
/// res = sum[min(a_ncols, b_nrows)] decomp(a, i) * (-B[i]s + m * 2^{-k*i} + E[i], B[i])
|
||||||
/// = (cs + m * a + e, c) with min(res_cols, b_cols) cols.
|
/// = (cs + m * a + e, c) with min(res_cols, b_cols) cols.
|
||||||
pub fn gadget_product_core<const OVERWRITE: bool, T>(
|
pub fn gadget_product_core<T>(
|
||||||
module: &Module,
|
module: &Module,
|
||||||
res_dft_0: &mut VecZnxDft,
|
res_dft_0: &mut VecZnxDft,
|
||||||
res_dft_1: &mut VecZnxDft,
|
res_dft_1: &mut VecZnxDft,
|
||||||
@@ -108,7 +108,6 @@ mod test {
|
|||||||
VecZnxDftOps, VecZnxOps, VmpPMat,
|
VecZnxDftOps, VecZnxOps, VmpPMat,
|
||||||
};
|
};
|
||||||
use sampling::source::{Source, new_seed};
|
use sampling::source::{Source, new_seed};
|
||||||
use std::cmp::min;
|
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn test_gadget_product_core() {
|
fn test_gadget_product_core() {
|
||||||
@@ -232,7 +231,7 @@ mod test {
|
|||||||
|
|
||||||
// res_dft_0 = DFT(gct_[0] * ct[1] = a * (-bs' + s + e) = -cs' + as + e')
|
// res_dft_0 = DFT(gct_[0] * ct[1] = a * (-bs' + s + e) = -cs' + as + e')
|
||||||
// res_dft_1 = DFT(gct_[1] * ct[1] = a * b = c)
|
// res_dft_1 = DFT(gct_[1] * ct[1] = a * b = c)
|
||||||
gadget_product_core::<true, _>(
|
gadget_product_core(
|
||||||
params.module(),
|
params.module(),
|
||||||
&mut res_dft_0,
|
&mut res_dft_0,
|
||||||
&mut res_dft_1,
|
&mut res_dft_1,
|
||||||
|
|||||||
Reference in New Issue
Block a user