mirror of
https://github.com/arnaucube/poulpy.git
synced 2026-02-10 13:16:44 +01:00
some cleaning
This commit is contained in:
@@ -4,23 +4,21 @@ use rlwe::{
|
|||||||
ciphertext::{Ciphertext, new_gadget_ciphertext},
|
ciphertext::{Ciphertext, new_gadget_ciphertext},
|
||||||
elem::Elem,
|
elem::Elem,
|
||||||
encryptor::{encrypt_grlwe_sk_thread_safe, encrypt_grlwe_sk_tmp_bytes},
|
encryptor::{encrypt_grlwe_sk_thread_safe, encrypt_grlwe_sk_tmp_bytes},
|
||||||
evaluator::{gadget_product_inplace_thread_safe, gadget_product_tmp_bytes},
|
evaluator::{gadget_product_inplace, gadget_product_tmp_bytes},
|
||||||
key_generator::gen_switching_key_thread_safe_tmp_bytes,
|
key_generator::gen_switching_key_thread_safe_tmp_bytes,
|
||||||
keys::SecretKey,
|
keys::SecretKey,
|
||||||
parameters::{Parameters, ParametersLiteral},
|
parameters::{Parameters, ParametersLiteral},
|
||||||
};
|
};
|
||||||
use sampling::source::Source;
|
use sampling::source::Source;
|
||||||
|
|
||||||
fn gadget_product_inplace(c: &mut Criterion) {
|
fn bench_gadget_product_inplace(c: &mut Criterion) {
|
||||||
fn gadget_product<'a>(
|
fn runner<'a>(
|
||||||
module: &'a Module,
|
module: &'a Module,
|
||||||
elem: &'a mut Elem<VecZnx>,
|
elem: &'a mut Elem<VecZnx>,
|
||||||
gadget_ct: &'a Ciphertext<VmpPMat>,
|
gadget_ct: &'a Ciphertext<VmpPMat>,
|
||||||
tmp_bytes: &'a mut [u8],
|
tmp_bytes: &'a mut [u8],
|
||||||
) -> Box<dyn FnMut() + 'a> {
|
) -> Box<dyn FnMut() + 'a> {
|
||||||
Box::new(move || {
|
Box::new(move || gadget_product_inplace::<true, _>(module, elem, gadget_ct, tmp_bytes))
|
||||||
gadget_product_inplace_thread_safe::<true, _>(module, elem, gadget_ct, tmp_bytes)
|
|
||||||
})
|
|
||||||
}
|
}
|
||||||
|
|
||||||
let mut b: criterion::BenchmarkGroup<'_, criterion::measurement::WallTime> =
|
let mut b: criterion::BenchmarkGroup<'_, criterion::measurement::WallTime> =
|
||||||
@@ -111,7 +109,7 @@ fn gadget_product_inplace(c: &mut Criterion) {
|
|||||||
);
|
);
|
||||||
|
|
||||||
let runners: [(String, Box<dyn FnMut()>); 1] = [(format!("gadget_product"), {
|
let runners: [(String, Box<dyn FnMut()>); 1] = [(format!("gadget_product"), {
|
||||||
gadget_product(params.module(), &mut ct.0, &gadget_ct, &mut tmp_bytes)
|
runner(params.module(), &mut ct.0, &gadget_ct, &mut tmp_bytes)
|
||||||
})];
|
})];
|
||||||
|
|
||||||
for (name, mut runner) in runners {
|
for (name, mut runner) in runners {
|
||||||
@@ -123,5 +121,5 @@ fn gadget_product_inplace(c: &mut Criterion) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
criterion_group!(benches, gadget_product_inplace);
|
criterion_group!(benches, bench_gadget_product_inplace);
|
||||||
criterion_main!(benches);
|
criterion_main!(benches);
|
||||||
|
|||||||
@@ -3,7 +3,7 @@ use rlwe::{
|
|||||||
ciphertext::{Ciphertext, new_gadget_ciphertext},
|
ciphertext::{Ciphertext, new_gadget_ciphertext},
|
||||||
decryptor::decrypt_rlwe_thread_safe,
|
decryptor::decrypt_rlwe_thread_safe,
|
||||||
encryptor::{encrypt_grlwe_sk_thread_safe, encrypt_grlwe_sk_tmp_bytes},
|
encryptor::{encrypt_grlwe_sk_thread_safe, encrypt_grlwe_sk_tmp_bytes},
|
||||||
evaluator::{gadget_product_inplace_thread_safe, gadget_product_tmp_bytes},
|
evaluator::{gadget_product_inplace, gadget_product_tmp_bytes},
|
||||||
key_generator::gen_switching_key_thread_safe_tmp_bytes,
|
key_generator::gen_switching_key_thread_safe_tmp_bytes,
|
||||||
keys::SecretKey,
|
keys::SecretKey,
|
||||||
parameters::{Parameters, ParametersLiteral},
|
parameters::{Parameters, ParametersLiteral},
|
||||||
@@ -112,12 +112,7 @@ fn main() {
|
|||||||
&mut tmp_bytes,
|
&mut tmp_bytes,
|
||||||
);
|
);
|
||||||
|
|
||||||
gadget_product_inplace_thread_safe::<true, _>(
|
gadget_product_inplace::<true, _>(params.module(), &mut ct.0, &gadget_ct, &mut tmp_bytes);
|
||||||
params.module(),
|
|
||||||
&mut ct.0,
|
|
||||||
&gadget_ct,
|
|
||||||
&mut tmp_bytes,
|
|
||||||
);
|
|
||||||
|
|
||||||
println!("ct.limbs()={}", ct.cols());
|
println!("ct.limbs()={}", ct.cols());
|
||||||
println!("gadget_ct.rows()={}", gadget_ct.rows());
|
println!("gadget_ct.rows()={}", gadget_ct.rows());
|
||||||
|
|||||||
@@ -58,7 +58,6 @@ where
|
|||||||
let limbs: usize = (log_q + log_base2k - 1) / log_base2k;
|
let limbs: usize = (log_q + log_base2k - 1) / log_base2k;
|
||||||
let elem_size = T::bytes_of(n, limbs);
|
let elem_size = T::bytes_of(n, limbs);
|
||||||
let mut ptr: usize = 0;
|
let mut ptr: usize = 0;
|
||||||
println!("{} {} {}", size, elem_size, bytes.len());
|
|
||||||
(0..size).for_each(|_| {
|
(0..size).for_each(|_| {
|
||||||
value.push(T::from_bytes(n, limbs, &mut bytes[ptr..]));
|
value.push(T::from_bytes(n, limbs, &mut bytes[ptr..]));
|
||||||
ptr += elem_size
|
ptr += elem_size
|
||||||
|
|||||||
@@ -20,7 +20,7 @@ pub fn gadget_product_tmp_bytes(
|
|||||||
+ 2 * module.bytes_of_vec_znx_dft(gct_cols)
|
+ 2 * module.bytes_of_vec_znx_dft(gct_cols)
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn gadget_product_inplace_thread_safe<const OVERWRITE: bool, T>(
|
pub fn gadget_product_inplace<const OVERWRITE: bool, T>(
|
||||||
module: &Module,
|
module: &Module,
|
||||||
res: &mut Elem<T>,
|
res: &mut Elem<T>,
|
||||||
b: &Ciphertext<VmpPMat>,
|
b: &Ciphertext<VmpPMat>,
|
||||||
@@ -31,7 +31,7 @@ pub fn gadget_product_inplace_thread_safe<const OVERWRITE: bool, T>(
|
|||||||
{
|
{
|
||||||
unsafe {
|
unsafe {
|
||||||
let a_ptr: *const T = res.at(1) as *const T;
|
let a_ptr: *const T = res.at(1) as *const T;
|
||||||
gadget_product_thread_safe::<OVERWRITE, T>(module, res, &*a_ptr, b, tmp_bytes);
|
gadget_product::<OVERWRITE, T>(module, res, &*a_ptr, b, tmp_bytes);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -49,7 +49,7 @@ pub fn gadget_product_inplace_thread_safe<const OVERWRITE: bool, T>(
|
|||||||
///
|
///
|
||||||
/// res = sum[min(a_ncols, b_nrows)] decomp(a, i) * (-B[i]s + m * 2^{-k*i} + E[i], B[i])
|
/// res = sum[min(a_ncols, b_nrows)] decomp(a, i) * (-B[i]s + m * 2^{-k*i} + E[i], B[i])
|
||||||
/// = (cs + m * a + e, c) with min(res_limbs, b_cols) limbs.
|
/// = (cs + m * a + e, c) with min(res_limbs, b_cols) limbs.
|
||||||
pub fn gadget_product_thread_safe<const OVERWRITE: bool, T>(
|
pub fn gadget_product<const OVERWRITE: bool, T>(
|
||||||
module: &Module,
|
module: &Module,
|
||||||
res: &mut Elem<T>,
|
res: &mut Elem<T>,
|
||||||
a: &T,
|
a: &T,
|
||||||
@@ -70,7 +70,7 @@ pub fn gadget_product_thread_safe<const OVERWRITE: bool, T>(
|
|||||||
let (tmp_bytes_vmp_apply_dft, tmp_bytes) = tmp_bytes.split_at_mut(bytes_vmp_apply_dft);
|
let (tmp_bytes_vmp_apply_dft, tmp_bytes) = tmp_bytes.split_at_mut(bytes_vmp_apply_dft);
|
||||||
let (tmp_bytes_c1_dft, tmp_bytes_res_dft) = tmp_bytes.split_at_mut(bytes_vec_znx_dft);
|
let (tmp_bytes_c1_dft, tmp_bytes_res_dft) = tmp_bytes.split_at_mut(bytes_vec_znx_dft);
|
||||||
|
|
||||||
let mut c1_dft: VecZnxDft = module.new_vec_znx_dft_from_bytes(cols, tmp_bytes_c1_dft);
|
let mut tmp_a_dft: VecZnxDft = module.new_vec_znx_dft_from_bytes(cols, tmp_bytes_c1_dft);
|
||||||
let mut res_dft: VecZnxDft = module.new_vec_znx_dft_from_bytes(cols, tmp_bytes_res_dft);
|
let mut res_dft: VecZnxDft = module.new_vec_znx_dft_from_bytes(cols, tmp_bytes_res_dft);
|
||||||
let mut res_big: VecZnxBig = res_dft.as_vec_znx_big();
|
let mut res_big: VecZnxBig = res_dft.as_vec_znx_big();
|
||||||
|
|
||||||
@@ -81,17 +81,10 @@ pub fn gadget_product_thread_safe<const OVERWRITE: bool, T>(
|
|||||||
let mut res_big_c1: VecZnxBig =
|
let mut res_big_c1: VecZnxBig =
|
||||||
module.new_vec_znx_big_from_bytes(cols >> 1, tmp_bytes_res_dft_c1);
|
module.new_vec_znx_big_from_bytes(cols >> 1, tmp_bytes_res_dft_c1);
|
||||||
|
|
||||||
// a_dft <- DFT(a)
|
// tmp_a_dft <- DFT(a)
|
||||||
module.vec_znx_dft(&mut c1_dft, a, a.cols());
|
|
||||||
|
|
||||||
// (n x cols) <- (n x limbs=rows) x (rows x cols)
|
// (n x cols) <- (n x limbs=rows) x (rows x cols)
|
||||||
// res_dft[a * (G0|G1)] <- sum[rows] DFT(a) x (DFT(G0)|DFT(G1))
|
// res_dft[a * (G0|G1)] <- sum[rows] tmp_a_dft x (DFT(G0)|DFT(G1))
|
||||||
module.vmp_apply_dft_to_dft(
|
gadget_product_core(module, &mut res_dft, a, b.at(0), tmp_bytes_vmp_apply_dft);
|
||||||
&mut res_dft,
|
|
||||||
&c1_dft,
|
|
||||||
&b.0.value[0],
|
|
||||||
tmp_bytes_vmp_apply_dft,
|
|
||||||
);
|
|
||||||
|
|
||||||
// res_big[a * (G0|G1)] <- IDFT(res_dft[a * (G0|G1)])
|
// res_big[a * (G0|G1)] <- IDFT(res_dft[a * (G0|G1)])
|
||||||
module.vec_znx_idft_tmp_a(&mut res_big, &mut res_dft, cols);
|
module.vec_znx_idft_tmp_a(&mut res_big, &mut res_dft, cols);
|
||||||
@@ -110,7 +103,25 @@ pub fn gadget_product_thread_safe<const OVERWRITE: bool, T>(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn rgsw_product_thread_safe<T>(
|
pub fn gadget_product_core<T>(
|
||||||
|
module: &Module,
|
||||||
|
res_dft: &mut VecZnxDft,
|
||||||
|
a: &T,
|
||||||
|
b: &VmpPMat,
|
||||||
|
tmp_bytes_vmp_apply_dft: &mut [u8],
|
||||||
|
) where
|
||||||
|
T: VecZnxCommon<Owned = T>,
|
||||||
|
Elem<T>: ElemVecZnx<T>,
|
||||||
|
{
|
||||||
|
// res_dft <- DFT(a)
|
||||||
|
module.vec_znx_dft(res_dft, a, a.cols());
|
||||||
|
|
||||||
|
// (n x cols) <- (n x limbs=rows) x (rows x cols)
|
||||||
|
// res_dft[a * (G0|G1)] <- sum[rows] res_dft x (DFT(G0)|DFT(G1))
|
||||||
|
module.vmp_apply_dft_to_dft_inplace(res_dft, b, tmp_bytes_vmp_apply_dft);
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn rgsw_product<T>(
|
||||||
module: &Module,
|
module: &Module,
|
||||||
res: &mut Elem<T>,
|
res: &mut Elem<T>,
|
||||||
a: &Ciphertext<T>,
|
a: &Ciphertext<T>,
|
||||||
|
|||||||
Reference in New Issue
Block a user