some cleaning

This commit is contained in:
Jean-Philippe Bossuat
2025-02-18 18:27:58 +01:00
parent 71f33f5983
commit 3937a43b08
4 changed files with 34 additions and 31 deletions

View File

@@ -58,7 +58,6 @@ where
let limbs: usize = (log_q + log_base2k - 1) / log_base2k;
let elem_size = T::bytes_of(n, limbs);
let mut ptr: usize = 0;
println!("{} {} {}", size, elem_size, bytes.len());
(0..size).for_each(|_| {
value.push(T::from_bytes(n, limbs, &mut bytes[ptr..]));
ptr += elem_size

View File

@@ -20,7 +20,7 @@ pub fn gadget_product_tmp_bytes(
+ 2 * module.bytes_of_vec_znx_dft(gct_cols)
}
pub fn gadget_product_inplace_thread_safe<const OVERWRITE: bool, T>(
pub fn gadget_product_inplace<const OVERWRITE: bool, T>(
module: &Module,
res: &mut Elem<T>,
b: &Ciphertext<VmpPMat>,
@@ -31,7 +31,7 @@ pub fn gadget_product_inplace_thread_safe<const OVERWRITE: bool, T>(
{
unsafe {
let a_ptr: *const T = res.at(1) as *const T;
gadget_product_thread_safe::<OVERWRITE, T>(module, res, &*a_ptr, b, tmp_bytes);
gadget_product::<OVERWRITE, T>(module, res, &*a_ptr, b, tmp_bytes);
}
}
@@ -49,7 +49,7 @@ pub fn gadget_product_inplace_thread_safe<const OVERWRITE: bool, T>(
///
/// res = sum[min(a_ncols, b_nrows)] decomp(a, i) * (-B[i]s + m * 2^{-k*i} + E[i], B[i])
/// = (cs + m * a + e, c) with min(res_limbs, b_cols) limbs.
pub fn gadget_product_thread_safe<const OVERWRITE: bool, T>(
pub fn gadget_product<const OVERWRITE: bool, T>(
module: &Module,
res: &mut Elem<T>,
a: &T,
@@ -70,7 +70,7 @@ pub fn gadget_product_thread_safe<const OVERWRITE: bool, T>(
let (tmp_bytes_vmp_apply_dft, tmp_bytes) = tmp_bytes.split_at_mut(bytes_vmp_apply_dft);
let (tmp_bytes_c1_dft, tmp_bytes_res_dft) = tmp_bytes.split_at_mut(bytes_vec_znx_dft);
let mut c1_dft: VecZnxDft = module.new_vec_znx_dft_from_bytes(cols, tmp_bytes_c1_dft);
let mut tmp_a_dft: VecZnxDft = module.new_vec_znx_dft_from_bytes(cols, tmp_bytes_c1_dft);
let mut res_dft: VecZnxDft = module.new_vec_znx_dft_from_bytes(cols, tmp_bytes_res_dft);
let mut res_big: VecZnxBig = res_dft.as_vec_znx_big();
@@ -81,17 +81,10 @@ pub fn gadget_product_thread_safe<const OVERWRITE: bool, T>(
let mut res_big_c1: VecZnxBig =
module.new_vec_znx_big_from_bytes(cols >> 1, tmp_bytes_res_dft_c1);
// a_dft <- DFT(a)
module.vec_znx_dft(&mut c1_dft, a, a.cols());
// tmp_a_dft <- DFT(a)
// (n x cols) <- (n x limbs=rows) x (rows x cols)
// res_dft[a * (G0|G1)] <- sum[rows] DFT(a) x (DFT(G0)|DFT(G1))
module.vmp_apply_dft_to_dft(
&mut res_dft,
&c1_dft,
&b.0.value[0],
tmp_bytes_vmp_apply_dft,
);
// res_dft[a * (G0|G1)] <- sum[rows] tmp_a_dft x (DFT(G0)|DFT(G1))
gadget_product_core(module, &mut res_dft, a, b.at(0), tmp_bytes_vmp_apply_dft);
// res_big[a * (G0|G1)] <- IDFT(res_dft[a * (G0|G1)])
module.vec_znx_idft_tmp_a(&mut res_big, &mut res_dft, cols);
@@ -110,7 +103,25 @@ pub fn gadget_product_thread_safe<const OVERWRITE: bool, T>(
}
}
pub fn rgsw_product_thread_safe<T>(
pub fn gadget_product_core<T>(
module: &Module,
res_dft: &mut VecZnxDft,
a: &T,
b: &VmpPMat,
tmp_bytes_vmp_apply_dft: &mut [u8],
) where
T: VecZnxCommon<Owned = T>,
Elem<T>: ElemVecZnx<T>,
{
// res_dft <- DFT(a)
module.vec_znx_dft(res_dft, a, a.cols());
// (n x cols) <- (n x limbs=rows) x (rows x cols)
// res_dft[a * (G0|G1)] <- sum[rows] res_dft x (DFT(G0)|DFT(G1))
module.vmp_apply_dft_to_dft_inplace(res_dft, b, tmp_bytes_vmp_apply_dft);
}
pub fn rgsw_product<T>(
module: &Module,
res: &mut Elem<T>,
a: &Ciphertext<T>,