some cleaning

This commit is contained in:
Jean-Philippe Bossuat
2025-02-18 18:27:58 +01:00
parent 71f33f5983
commit 3937a43b08
4 changed files with 34 additions and 31 deletions

View File

@@ -4,23 +4,21 @@ use rlwe::{
ciphertext::{Ciphertext, new_gadget_ciphertext}, ciphertext::{Ciphertext, new_gadget_ciphertext},
elem::Elem, elem::Elem,
encryptor::{encrypt_grlwe_sk_thread_safe, encrypt_grlwe_sk_tmp_bytes}, encryptor::{encrypt_grlwe_sk_thread_safe, encrypt_grlwe_sk_tmp_bytes},
evaluator::{gadget_product_inplace_thread_safe, gadget_product_tmp_bytes}, evaluator::{gadget_product_inplace, gadget_product_tmp_bytes},
key_generator::gen_switching_key_thread_safe_tmp_bytes, key_generator::gen_switching_key_thread_safe_tmp_bytes,
keys::SecretKey, keys::SecretKey,
parameters::{Parameters, ParametersLiteral}, parameters::{Parameters, ParametersLiteral},
}; };
use sampling::source::Source; use sampling::source::Source;
fn gadget_product_inplace(c: &mut Criterion) { fn bench_gadget_product_inplace(c: &mut Criterion) {
fn gadget_product<'a>( fn runner<'a>(
module: &'a Module, module: &'a Module,
elem: &'a mut Elem<VecZnx>, elem: &'a mut Elem<VecZnx>,
gadget_ct: &'a Ciphertext<VmpPMat>, gadget_ct: &'a Ciphertext<VmpPMat>,
tmp_bytes: &'a mut [u8], tmp_bytes: &'a mut [u8],
) -> Box<dyn FnMut() + 'a> { ) -> Box<dyn FnMut() + 'a> {
Box::new(move || { Box::new(move || gadget_product_inplace::<true, _>(module, elem, gadget_ct, tmp_bytes))
gadget_product_inplace_thread_safe::<true, _>(module, elem, gadget_ct, tmp_bytes)
})
} }
let mut b: criterion::BenchmarkGroup<'_, criterion::measurement::WallTime> = let mut b: criterion::BenchmarkGroup<'_, criterion::measurement::WallTime> =
@@ -111,7 +109,7 @@ fn gadget_product_inplace(c: &mut Criterion) {
); );
let runners: [(String, Box<dyn FnMut()>); 1] = [(format!("gadget_product"), { let runners: [(String, Box<dyn FnMut()>); 1] = [(format!("gadget_product"), {
gadget_product(params.module(), &mut ct.0, &gadget_ct, &mut tmp_bytes) runner(params.module(), &mut ct.0, &gadget_ct, &mut tmp_bytes)
})]; })];
for (name, mut runner) in runners { for (name, mut runner) in runners {
@@ -123,5 +121,5 @@ fn gadget_product_inplace(c: &mut Criterion) {
} }
} }
criterion_group!(benches, gadget_product_inplace); criterion_group!(benches, bench_gadget_product_inplace);
criterion_main!(benches); criterion_main!(benches);

View File

@@ -3,7 +3,7 @@ use rlwe::{
ciphertext::{Ciphertext, new_gadget_ciphertext}, ciphertext::{Ciphertext, new_gadget_ciphertext},
decryptor::decrypt_rlwe_thread_safe, decryptor::decrypt_rlwe_thread_safe,
encryptor::{encrypt_grlwe_sk_thread_safe, encrypt_grlwe_sk_tmp_bytes}, encryptor::{encrypt_grlwe_sk_thread_safe, encrypt_grlwe_sk_tmp_bytes},
evaluator::{gadget_product_inplace_thread_safe, gadget_product_tmp_bytes}, evaluator::{gadget_product_inplace, gadget_product_tmp_bytes},
key_generator::gen_switching_key_thread_safe_tmp_bytes, key_generator::gen_switching_key_thread_safe_tmp_bytes,
keys::SecretKey, keys::SecretKey,
parameters::{Parameters, ParametersLiteral}, parameters::{Parameters, ParametersLiteral},
@@ -112,12 +112,7 @@ fn main() {
&mut tmp_bytes, &mut tmp_bytes,
); );
gadget_product_inplace_thread_safe::<true, _>( gadget_product_inplace::<true, _>(params.module(), &mut ct.0, &gadget_ct, &mut tmp_bytes);
params.module(),
&mut ct.0,
&gadget_ct,
&mut tmp_bytes,
);
println!("ct.limbs()={}", ct.cols()); println!("ct.limbs()={}", ct.cols());
println!("gadget_ct.rows()={}", gadget_ct.rows()); println!("gadget_ct.rows()={}", gadget_ct.rows());

View File

@@ -58,7 +58,6 @@ where
let limbs: usize = (log_q + log_base2k - 1) / log_base2k; let limbs: usize = (log_q + log_base2k - 1) / log_base2k;
let elem_size = T::bytes_of(n, limbs); let elem_size = T::bytes_of(n, limbs);
let mut ptr: usize = 0; let mut ptr: usize = 0;
println!("{} {} {}", size, elem_size, bytes.len());
(0..size).for_each(|_| { (0..size).for_each(|_| {
value.push(T::from_bytes(n, limbs, &mut bytes[ptr..])); value.push(T::from_bytes(n, limbs, &mut bytes[ptr..]));
ptr += elem_size ptr += elem_size

View File

@@ -20,7 +20,7 @@ pub fn gadget_product_tmp_bytes(
+ 2 * module.bytes_of_vec_znx_dft(gct_cols) + 2 * module.bytes_of_vec_znx_dft(gct_cols)
} }
pub fn gadget_product_inplace_thread_safe<const OVERWRITE: bool, T>( pub fn gadget_product_inplace<const OVERWRITE: bool, T>(
module: &Module, module: &Module,
res: &mut Elem<T>, res: &mut Elem<T>,
b: &Ciphertext<VmpPMat>, b: &Ciphertext<VmpPMat>,
@@ -31,7 +31,7 @@ pub fn gadget_product_inplace_thread_safe<const OVERWRITE: bool, T>(
{ {
unsafe { unsafe {
let a_ptr: *const T = res.at(1) as *const T; let a_ptr: *const T = res.at(1) as *const T;
gadget_product_thread_safe::<OVERWRITE, T>(module, res, &*a_ptr, b, tmp_bytes); gadget_product::<OVERWRITE, T>(module, res, &*a_ptr, b, tmp_bytes);
} }
} }
@@ -49,7 +49,7 @@ pub fn gadget_product_inplace_thread_safe<const OVERWRITE: bool, T>(
/// ///
/// res = sum[min(a_ncols, b_nrows)] decomp(a, i) * (-B[i]s + m * 2^{-k*i} + E[i], B[i]) /// res = sum[min(a_ncols, b_nrows)] decomp(a, i) * (-B[i]s + m * 2^{-k*i} + E[i], B[i])
/// = (cs + m * a + e, c) with min(res_limbs, b_cols) limbs. /// = (cs + m * a + e, c) with min(res_limbs, b_cols) limbs.
pub fn gadget_product_thread_safe<const OVERWRITE: bool, T>( pub fn gadget_product<const OVERWRITE: bool, T>(
module: &Module, module: &Module,
res: &mut Elem<T>, res: &mut Elem<T>,
a: &T, a: &T,
@@ -70,7 +70,7 @@ pub fn gadget_product_thread_safe<const OVERWRITE: bool, T>(
let (tmp_bytes_vmp_apply_dft, tmp_bytes) = tmp_bytes.split_at_mut(bytes_vmp_apply_dft); let (tmp_bytes_vmp_apply_dft, tmp_bytes) = tmp_bytes.split_at_mut(bytes_vmp_apply_dft);
let (tmp_bytes_c1_dft, tmp_bytes_res_dft) = tmp_bytes.split_at_mut(bytes_vec_znx_dft); let (tmp_bytes_c1_dft, tmp_bytes_res_dft) = tmp_bytes.split_at_mut(bytes_vec_znx_dft);
let mut c1_dft: VecZnxDft = module.new_vec_znx_dft_from_bytes(cols, tmp_bytes_c1_dft); let mut tmp_a_dft: VecZnxDft = module.new_vec_znx_dft_from_bytes(cols, tmp_bytes_c1_dft);
let mut res_dft: VecZnxDft = module.new_vec_znx_dft_from_bytes(cols, tmp_bytes_res_dft); let mut res_dft: VecZnxDft = module.new_vec_znx_dft_from_bytes(cols, tmp_bytes_res_dft);
let mut res_big: VecZnxBig = res_dft.as_vec_znx_big(); let mut res_big: VecZnxBig = res_dft.as_vec_znx_big();
@@ -81,17 +81,10 @@ pub fn gadget_product_thread_safe<const OVERWRITE: bool, T>(
let mut res_big_c1: VecZnxBig = let mut res_big_c1: VecZnxBig =
module.new_vec_znx_big_from_bytes(cols >> 1, tmp_bytes_res_dft_c1); module.new_vec_znx_big_from_bytes(cols >> 1, tmp_bytes_res_dft_c1);
// a_dft <- DFT(a) // tmp_a_dft <- DFT(a)
module.vec_znx_dft(&mut c1_dft, a, a.cols());
// (n x cols) <- (n x limbs=rows) x (rows x cols) // (n x cols) <- (n x limbs=rows) x (rows x cols)
// res_dft[a * (G0|G1)] <- sum[rows] DFT(a) x (DFT(G0)|DFT(G1)) // res_dft[a * (G0|G1)] <- sum[rows] tmp_a_dft x (DFT(G0)|DFT(G1))
module.vmp_apply_dft_to_dft( gadget_product_core(module, &mut res_dft, a, b.at(0), tmp_bytes_vmp_apply_dft);
&mut res_dft,
&c1_dft,
&b.0.value[0],
tmp_bytes_vmp_apply_dft,
);
// res_big[a * (G0|G1)] <- IDFT(res_dft[a * (G0|G1)]) // res_big[a * (G0|G1)] <- IDFT(res_dft[a * (G0|G1)])
module.vec_znx_idft_tmp_a(&mut res_big, &mut res_dft, cols); module.vec_znx_idft_tmp_a(&mut res_big, &mut res_dft, cols);
@@ -110,7 +103,25 @@ pub fn gadget_product_thread_safe<const OVERWRITE: bool, T>(
} }
} }
pub fn rgsw_product_thread_safe<T>( pub fn gadget_product_core<T>(
module: &Module,
res_dft: &mut VecZnxDft,
a: &T,
b: &VmpPMat,
tmp_bytes_vmp_apply_dft: &mut [u8],
) where
T: VecZnxCommon<Owned = T>,
Elem<T>: ElemVecZnx<T>,
{
// res_dft <- DFT(a)
module.vec_znx_dft(res_dft, a, a.cols());
// (n x cols) <- (n x limbs=rows) x (rows x cols)
// res_dft[a * (G0|G1)] <- sum[rows] res_dft x (DFT(G0)|DFT(G1))
module.vmp_apply_dft_to_dft_inplace(res_dft, b, tmp_bytes_vmp_apply_dft);
}
pub fn rgsw_product<T>(
module: &Module, module: &Module,
res: &mut Elem<T>, res: &mut Elem<T>,
a: &Ciphertext<T>, a: &Ciphertext<T>,