From 3937a43b08adb533d4ea7ee5bbe6497bc6bb774b Mon Sep 17 00:00:00 2001 From: Jean-Philippe Bossuat Date: Tue, 18 Feb 2025 18:27:58 +0100 Subject: [PATCH] some cleaning --- rlwe/benches/gadget_product.rs | 14 +++++------ rlwe/examples/gadget_product.rs | 9 ++------ rlwe/src/elem.rs | 1 - rlwe/src/evaluator.rs | 41 +++++++++++++++++++++------------ 4 files changed, 34 insertions(+), 31 deletions(-) diff --git a/rlwe/benches/gadget_product.rs b/rlwe/benches/gadget_product.rs index fc75341..b91b6c0 100644 --- a/rlwe/benches/gadget_product.rs +++ b/rlwe/benches/gadget_product.rs @@ -4,23 +4,21 @@ use rlwe::{ ciphertext::{Ciphertext, new_gadget_ciphertext}, elem::Elem, encryptor::{encrypt_grlwe_sk_thread_safe, encrypt_grlwe_sk_tmp_bytes}, - evaluator::{gadget_product_inplace_thread_safe, gadget_product_tmp_bytes}, + evaluator::{gadget_product_inplace, gadget_product_tmp_bytes}, key_generator::gen_switching_key_thread_safe_tmp_bytes, keys::SecretKey, parameters::{Parameters, ParametersLiteral}, }; use sampling::source::Source; -fn gadget_product_inplace(c: &mut Criterion) { - fn gadget_product<'a>( +fn bench_gadget_product_inplace(c: &mut Criterion) { + fn runner<'a>( module: &'a Module, elem: &'a mut Elem, gadget_ct: &'a Ciphertext, tmp_bytes: &'a mut [u8], ) -> Box { - Box::new(move || { - gadget_product_inplace_thread_safe::(module, elem, gadget_ct, tmp_bytes) - }) + Box::new(move || gadget_product_inplace::(module, elem, gadget_ct, tmp_bytes)) } let mut b: criterion::BenchmarkGroup<'_, criterion::measurement::WallTime> = @@ -111,7 +109,7 @@ fn gadget_product_inplace(c: &mut Criterion) { ); let runners: [(String, Box); 1] = [(format!("gadget_product"), { - gadget_product(params.module(), &mut ct.0, &gadget_ct, &mut tmp_bytes) + runner(params.module(), &mut ct.0, &gadget_ct, &mut tmp_bytes) })]; for (name, mut runner) in runners { @@ -123,5 +121,5 @@ fn gadget_product_inplace(c: &mut Criterion) { } } -criterion_group!(benches, gadget_product_inplace); +criterion_group!(benches, bench_gadget_product_inplace); criterion_main!(benches); diff --git a/rlwe/examples/gadget_product.rs b/rlwe/examples/gadget_product.rs index 40d94a9..f9e205f 100644 --- a/rlwe/examples/gadget_product.rs +++ b/rlwe/examples/gadget_product.rs @@ -3,7 +3,7 @@ use rlwe::{ ciphertext::{Ciphertext, new_gadget_ciphertext}, decryptor::decrypt_rlwe_thread_safe, encryptor::{encrypt_grlwe_sk_thread_safe, encrypt_grlwe_sk_tmp_bytes}, - evaluator::{gadget_product_inplace_thread_safe, gadget_product_tmp_bytes}, + evaluator::{gadget_product_inplace, gadget_product_tmp_bytes}, key_generator::gen_switching_key_thread_safe_tmp_bytes, keys::SecretKey, parameters::{Parameters, ParametersLiteral}, @@ -112,12 +112,7 @@ fn main() { &mut tmp_bytes, ); - gadget_product_inplace_thread_safe::( - params.module(), - &mut ct.0, - &gadget_ct, - &mut tmp_bytes, - ); + gadget_product_inplace::(params.module(), &mut ct.0, &gadget_ct, &mut tmp_bytes); println!("ct.limbs()={}", ct.cols()); println!("gadget_ct.rows()={}", gadget_ct.rows()); diff --git a/rlwe/src/elem.rs b/rlwe/src/elem.rs index 31ce48e..5c1348d 100644 --- a/rlwe/src/elem.rs +++ b/rlwe/src/elem.rs @@ -58,7 +58,6 @@ where let limbs: usize = (log_q + log_base2k - 1) / log_base2k; let elem_size = T::bytes_of(n, limbs); let mut ptr: usize = 0; - println!("{} {} {}", size, elem_size, bytes.len()); (0..size).for_each(|_| { value.push(T::from_bytes(n, limbs, &mut bytes[ptr..])); ptr += elem_size diff --git a/rlwe/src/evaluator.rs b/rlwe/src/evaluator.rs index e224af2..30dffc4 100644 --- a/rlwe/src/evaluator.rs +++ b/rlwe/src/evaluator.rs @@ -20,7 +20,7 @@ pub fn gadget_product_tmp_bytes( + 2 * module.bytes_of_vec_znx_dft(gct_cols) } -pub fn gadget_product_inplace_thread_safe( +pub fn gadget_product_inplace( module: &Module, res: &mut Elem, b: &Ciphertext, @@ -31,7 +31,7 @@ pub fn gadget_product_inplace_thread_safe( { unsafe { let a_ptr: *const T = res.at(1) as *const T; - gadget_product_thread_safe::(module, res, &*a_ptr, b, tmp_bytes); + gadget_product::(module, res, &*a_ptr, b, tmp_bytes); } } @@ -49,7 +49,7 @@ pub fn gadget_product_inplace_thread_safe( /// /// res = sum[min(a_ncols, b_nrows)] decomp(a, i) * (-B[i]s + m * 2^{-k*i} + E[i], B[i]) /// = (cs + m * a + e, c) with min(res_limbs, b_cols) limbs. -pub fn gadget_product_thread_safe( +pub fn gadget_product( module: &Module, res: &mut Elem, a: &T, @@ -70,7 +70,7 @@ pub fn gadget_product_thread_safe( let (tmp_bytes_vmp_apply_dft, tmp_bytes) = tmp_bytes.split_at_mut(bytes_vmp_apply_dft); let (tmp_bytes_c1_dft, tmp_bytes_res_dft) = tmp_bytes.split_at_mut(bytes_vec_znx_dft); - let mut c1_dft: VecZnxDft = module.new_vec_znx_dft_from_bytes(cols, tmp_bytes_c1_dft); + let mut tmp_a_dft: VecZnxDft = module.new_vec_znx_dft_from_bytes(cols, tmp_bytes_c1_dft); let mut res_dft: VecZnxDft = module.new_vec_znx_dft_from_bytes(cols, tmp_bytes_res_dft); let mut res_big: VecZnxBig = res_dft.as_vec_znx_big(); @@ -81,17 +81,10 @@ pub fn gadget_product_thread_safe( let mut res_big_c1: VecZnxBig = module.new_vec_znx_big_from_bytes(cols >> 1, tmp_bytes_res_dft_c1); - // a_dft <- DFT(a) - module.vec_znx_dft(&mut c1_dft, a, a.cols()); - + // tmp_a_dft <- DFT(a) // (n x cols) <- (n x limbs=rows) x (rows x cols) - // res_dft[a * (G0|G1)] <- sum[rows] DFT(a) x (DFT(G0)|DFT(G1)) - module.vmp_apply_dft_to_dft( - &mut res_dft, - &c1_dft, - &b.0.value[0], - tmp_bytes_vmp_apply_dft, - ); + // res_dft[a * (G0|G1)] <- sum[rows] tmp_a_dft x (DFT(G0)|DFT(G1)) + gadget_product_core(module, &mut res_dft, a, b.at(0), tmp_bytes_vmp_apply_dft); // res_big[a * (G0|G1)] <- IDFT(res_dft[a * (G0|G1)]) module.vec_znx_idft_tmp_a(&mut res_big, &mut res_dft, cols); @@ -110,7 +103,25 @@ pub fn gadget_product_thread_safe( } } -pub fn rgsw_product_thread_safe( +pub fn gadget_product_core( + module: &Module, + res_dft: &mut VecZnxDft, + a: &T, + b: &VmpPMat, + tmp_bytes_vmp_apply_dft: &mut [u8], +) where + T: VecZnxCommon, + Elem: ElemVecZnx, +{ + // res_dft <- DFT(a) + module.vec_znx_dft(res_dft, a, a.cols()); + + // (n x cols) <- (n x limbs=rows) x (rows x cols) + // res_dft[a * (G0|G1)] <- sum[rows] res_dft x (DFT(G0)|DFT(G1)) + module.vmp_apply_dft_to_dft_inplace(res_dft, b, tmp_bytes_vmp_apply_dft); +} + +pub fn rgsw_product( module: &Module, res: &mut Elem, a: &Ciphertext,