some cleaning

This commit is contained in:
Jean-Philippe Bossuat
2025-02-18 18:27:58 +01:00
parent 71f33f5983
commit 3937a43b08
4 changed files with 34 additions and 31 deletions

View File

@@ -4,23 +4,21 @@ use rlwe::{
ciphertext::{Ciphertext, new_gadget_ciphertext},
elem::Elem,
encryptor::{encrypt_grlwe_sk_thread_safe, encrypt_grlwe_sk_tmp_bytes},
evaluator::{gadget_product_inplace_thread_safe, gadget_product_tmp_bytes},
evaluator::{gadget_product_inplace, gadget_product_tmp_bytes},
key_generator::gen_switching_key_thread_safe_tmp_bytes,
keys::SecretKey,
parameters::{Parameters, ParametersLiteral},
};
use sampling::source::Source;
fn gadget_product_inplace(c: &mut Criterion) {
fn gadget_product<'a>(
fn bench_gadget_product_inplace(c: &mut Criterion) {
fn runner<'a>(
module: &'a Module,
elem: &'a mut Elem<VecZnx>,
gadget_ct: &'a Ciphertext<VmpPMat>,
tmp_bytes: &'a mut [u8],
) -> Box<dyn FnMut() + 'a> {
Box::new(move || {
gadget_product_inplace_thread_safe::<true, _>(module, elem, gadget_ct, tmp_bytes)
})
Box::new(move || gadget_product_inplace::<true, _>(module, elem, gadget_ct, tmp_bytes))
}
let mut b: criterion::BenchmarkGroup<'_, criterion::measurement::WallTime> =
@@ -111,7 +109,7 @@ fn gadget_product_inplace(c: &mut Criterion) {
);
let runners: [(String, Box<dyn FnMut()>); 1] = [(format!("gadget_product"), {
gadget_product(params.module(), &mut ct.0, &gadget_ct, &mut tmp_bytes)
runner(params.module(), &mut ct.0, &gadget_ct, &mut tmp_bytes)
})];
for (name, mut runner) in runners {
@@ -123,5 +121,5 @@ fn gadget_product_inplace(c: &mut Criterion) {
}
}
criterion_group!(benches, gadget_product_inplace);
criterion_group!(benches, bench_gadget_product_inplace);
criterion_main!(benches);

View File

@@ -3,7 +3,7 @@ use rlwe::{
ciphertext::{Ciphertext, new_gadget_ciphertext},
decryptor::decrypt_rlwe_thread_safe,
encryptor::{encrypt_grlwe_sk_thread_safe, encrypt_grlwe_sk_tmp_bytes},
evaluator::{gadget_product_inplace_thread_safe, gadget_product_tmp_bytes},
evaluator::{gadget_product_inplace, gadget_product_tmp_bytes},
key_generator::gen_switching_key_thread_safe_tmp_bytes,
keys::SecretKey,
parameters::{Parameters, ParametersLiteral},
@@ -112,12 +112,7 @@ fn main() {
&mut tmp_bytes,
);
gadget_product_inplace_thread_safe::<true, _>(
params.module(),
&mut ct.0,
&gadget_ct,
&mut tmp_bytes,
);
gadget_product_inplace::<true, _>(params.module(), &mut ct.0, &gadget_ct, &mut tmp_bytes);
println!("ct.limbs()={}", ct.cols());
println!("gadget_ct.rows()={}", gadget_ct.rows());

View File

@@ -58,7 +58,6 @@ where
let limbs: usize = (log_q + log_base2k - 1) / log_base2k;
let elem_size = T::bytes_of(n, limbs);
let mut ptr: usize = 0;
println!("{} {} {}", size, elem_size, bytes.len());
(0..size).for_each(|_| {
value.push(T::from_bytes(n, limbs, &mut bytes[ptr..]));
ptr += elem_size

View File

@@ -20,7 +20,7 @@ pub fn gadget_product_tmp_bytes(
+ 2 * module.bytes_of_vec_znx_dft(gct_cols)
}
pub fn gadget_product_inplace_thread_safe<const OVERWRITE: bool, T>(
pub fn gadget_product_inplace<const OVERWRITE: bool, T>(
module: &Module,
res: &mut Elem<T>,
b: &Ciphertext<VmpPMat>,
@@ -31,7 +31,7 @@ pub fn gadget_product_inplace_thread_safe<const OVERWRITE: bool, T>(
{
unsafe {
let a_ptr: *const T = res.at(1) as *const T;
gadget_product_thread_safe::<OVERWRITE, T>(module, res, &*a_ptr, b, tmp_bytes);
gadget_product::<OVERWRITE, T>(module, res, &*a_ptr, b, tmp_bytes);
}
}
@@ -49,7 +49,7 @@ pub fn gadget_product_inplace_thread_safe<const OVERWRITE: bool, T>(
///
/// res = sum[min(a_ncols, b_nrows)] decomp(a, i) * (-B[i]s + m * 2^{-k*i} + E[i], B[i])
/// = (cs + m * a + e, c) with min(res_limbs, b_cols) limbs.
pub fn gadget_product_thread_safe<const OVERWRITE: bool, T>(
pub fn gadget_product<const OVERWRITE: bool, T>(
module: &Module,
res: &mut Elem<T>,
a: &T,
@@ -70,7 +70,7 @@ pub fn gadget_product_thread_safe<const OVERWRITE: bool, T>(
let (tmp_bytes_vmp_apply_dft, tmp_bytes) = tmp_bytes.split_at_mut(bytes_vmp_apply_dft);
let (tmp_bytes_c1_dft, tmp_bytes_res_dft) = tmp_bytes.split_at_mut(bytes_vec_znx_dft);
let mut c1_dft: VecZnxDft = module.new_vec_znx_dft_from_bytes(cols, tmp_bytes_c1_dft);
let mut tmp_a_dft: VecZnxDft = module.new_vec_znx_dft_from_bytes(cols, tmp_bytes_c1_dft);
let mut res_dft: VecZnxDft = module.new_vec_znx_dft_from_bytes(cols, tmp_bytes_res_dft);
let mut res_big: VecZnxBig = res_dft.as_vec_znx_big();
@@ -81,17 +81,10 @@ pub fn gadget_product_thread_safe<const OVERWRITE: bool, T>(
let mut res_big_c1: VecZnxBig =
module.new_vec_znx_big_from_bytes(cols >> 1, tmp_bytes_res_dft_c1);
// a_dft <- DFT(a)
module.vec_znx_dft(&mut c1_dft, a, a.cols());
// tmp_a_dft <- DFT(a)
// (n x cols) <- (n x limbs=rows) x (rows x cols)
// res_dft[a * (G0|G1)] <- sum[rows] DFT(a) x (DFT(G0)|DFT(G1))
module.vmp_apply_dft_to_dft(
&mut res_dft,
&c1_dft,
&b.0.value[0],
tmp_bytes_vmp_apply_dft,
);
// res_dft[a * (G0|G1)] <- sum[rows] tmp_a_dft x (DFT(G0)|DFT(G1))
gadget_product_core(module, &mut res_dft, a, b.at(0), tmp_bytes_vmp_apply_dft);
// res_big[a * (G0|G1)] <- IDFT(res_dft[a * (G0|G1)])
module.vec_znx_idft_tmp_a(&mut res_big, &mut res_dft, cols);
@@ -110,7 +103,25 @@ pub fn gadget_product_thread_safe<const OVERWRITE: bool, T>(
}
}
pub fn rgsw_product_thread_safe<T>(
pub fn gadget_product_core<T>(
module: &Module,
res_dft: &mut VecZnxDft,
a: &T,
b: &VmpPMat,
tmp_bytes_vmp_apply_dft: &mut [u8],
) where
T: VecZnxCommon<Owned = T>,
Elem<T>: ElemVecZnx<T>,
{
// res_dft <- DFT(a)
module.vec_znx_dft(res_dft, a, a.cols());
// (n x cols) <- (n x limbs=rows) x (rows x cols)
// res_dft[a * (G0|G1)] <- sum[rows] res_dft x (DFT(G0)|DFT(G1))
module.vmp_apply_dft_to_dft_inplace(res_dft, b, tmp_bytes_vmp_apply_dft);
}
pub fn rgsw_product<T>(
module: &Module,
res: &mut Elem<T>,
a: &Ciphertext<T>,