From 917a4724375166e3f504096fd40178cc33e2c849 Mon Sep 17 00:00:00 2001 From: Jean-Philippe Bossuat Date: Tue, 29 Apr 2025 18:14:16 +0200 Subject: [PATCH] wip: change of approach, enables to select columns on which to operate --- base2k/examples/rlwe_encrypt.rs | 65 +-- base2k/src/commons.rs | 10 +- base2k/src/internals.rs | 96 ---- base2k/src/scalar_znx_dft.rs | 12 +- base2k/src/vec_znx.rs | 4 +- base2k/src/vec_znx_big_ops.rs | 6 +- base2k/src/vec_znx_ops.rs | 795 ++++++++------------------------ 7 files changed, 250 insertions(+), 738 deletions(-) diff --git a/base2k/examples/rlwe_encrypt.rs b/base2k/examples/rlwe_encrypt.rs index 803f371..395fdf6 100644 --- a/base2k/examples/rlwe_encrypt.rs +++ b/base2k/examples/rlwe_encrypt.rs @@ -1,6 +1,6 @@ use base2k::{ Encoding, FFT64, Module, Sampling, Scalar, ScalarZnxDft, ScalarZnxDftOps, VecZnx, VecZnxBig, VecZnxBigOps, VecZnxDft, - VecZnxDftOps, VecZnxOps, ZnxInfos, ZnxLayout, alloc_aligned, + VecZnxDftOps, VecZnxOps, ZnxInfos, alloc_aligned, }; use itertools::izip; use sampling::source::Source; @@ -13,13 +13,11 @@ fn main() { let log_scale: usize = msg_size * log_base2k - 5; let module: Module = Module::::new(n); - let mut carry: Vec = alloc_aligned(module.vec_znx_big_normalize_tmp_bytes(1)); + let mut carry: Vec = alloc_aligned(module.vec_znx_big_normalize_tmp_bytes(2)); let seed: [u8; 32] = [0; 32]; let mut source: Source = Source::new(seed); - let mut res: VecZnx = module.new_vec_znx(1, ct_size); - // s <- Z_{-1, 0, 1}[X]/(X^{N}+1) let mut s: Scalar = Scalar::new(n); s.fill_ternary_prob(0.5, &mut source); @@ -30,47 +28,50 @@ fn main() { // s_ppol <- DFT(s) module.svp_prepare(&mut s_ppol, &s); - // a <- Z_{2^prec}[X]/(X^{N}+1) - let mut a: VecZnx = module.new_vec_znx(1, ct_size); - module.fill_uniform(log_base2k, &mut a, 0, ct_size, &mut source); + // ct = (c0, c1) + let mut ct: VecZnx = module.new_vec_znx(2, ct_size); + + // Fill c1 with random values + module.fill_uniform(log_base2k, &mut ct, 1, ct_size, &mut source); // Scratch space for DFT values - let mut buf_dft: VecZnxDft = module.new_vec_znx_dft(1, a.size()); + let mut buf_dft: VecZnxDft = module.new_vec_znx_dft(1, ct.size()); - // Applies buf_dft <- s * a - module.svp_apply_dft(&mut buf_dft, &s_ppol, &a); + // Applies buf_dft <- s * c1 + module.svp_apply_dft( + &mut buf_dft, // DFT(c1 * s) + &s_ppol, + &ct, + 1, // c1 + ); - // Alias scratch space + // Alias scratch space (VecZnxDftis always at least as big as VecZnxBig) let mut buf_big: VecZnxBig = buf_dft.as_vec_znx_big(); - // buf_big <- IDFT(buf_dft) (not normalized) + // BIG(c1 * s) <- IDFT(DFT(c1 * s)) (not normalized) module.vec_znx_idft_tmp_a(&mut buf_big, &mut buf_dft); - println!("{:?}", buf_big.raw()); - + // m <- (0) let mut m: VecZnx = module.new_vec_znx(1, msg_size); - let mut want: Vec = vec![0; n]; want.iter_mut() .for_each(|x| *x = source.next_u64n(16, 15) as i64); - - // m m.encode_vec_i64(0, log_base2k, log_scale, &want, 4); m.normalize(log_base2k, &mut carry); - // buf_big <- m - buf_big + // m - BIG(c1 * s) module.vec_znx_big_sub_small_ab_inplace(&mut buf_big, &m); - println!("{:?}", buf_big.raw()); + // c0 <- m - BIG(c1 * s) + module.vec_znx_big_normalize(log_base2k, &mut ct, &buf_big, &mut carry); - // b <- normalize(buf_big) + e - let mut b: VecZnx = module.new_vec_znx(1, ct_size); - module.vec_znx_big_normalize(log_base2k, &mut b, &buf_big, &mut carry); - b.print(n); + ct.print(ct.sl()); + + // (c0 + e, c1) module.add_normal( log_base2k, - &mut b, - 0, + &mut ct, + 0, // c0 log_base2k * ct_size, &mut source, 3.2, @@ -79,16 +80,16 @@ fn main() { // Decrypt - // buf_big <- a * s - module.svp_apply_dft(&mut buf_dft, &s_ppol, &a); + // DFT(c1 * s) + module.svp_apply_dft(&mut buf_dft, &s_ppol, &ct, 1); + // BIG(c1 * s) = IDFT(DFT(c1 * s)) module.vec_znx_idft_tmp_a(&mut buf_big, &mut buf_dft); - // buf_big <- a * s + b - module.vec_znx_big_add_small_inplace(&mut buf_big, &b); + // BIG(c1 * s) + c0 + module.vec_znx_big_add_small_inplace(&mut buf_big, &ct); - println!("raw: {:?}", &buf_big.raw()); - - // res <- normalize(buf_big) + // m + e <- BIG(c1 * s + c0) + let mut res: VecZnx = module.new_vec_znx(1, ct_size); module.vec_znx_big_normalize(log_base2k, &mut res, &buf_big, &mut carry); // have = m * 2^{log_scale} + e diff --git a/base2k/src/commons.rs b/base2k/src/commons.rs index 969897d..d5f60ee 100644 --- a/base2k/src/commons.rs +++ b/base2k/src/commons.rs @@ -81,12 +81,12 @@ pub trait ZnxLayout: ZnxInfos { } /// Returns non-mutable reference to the (i, j)-th small polynomial. - fn at_poly(&self, i: usize, j: usize) -> &[Self::Scalar] { + fn at(&self, i: usize, j: usize) -> &[Self::Scalar] { unsafe { std::slice::from_raw_parts(self.at_ptr(i, j), self.n()) } } /// Returns mutable reference to the (i, j)-th small polynomial. - fn at_poly_mut(&mut self, i: usize, j: usize) -> &mut [Self::Scalar] { + fn at_mut(&mut self, i: usize, j: usize) -> &mut [Self::Scalar] { unsafe { std::slice::from_raw_parts_mut(self.at_mut_ptr(i, j), self.n()) } } @@ -219,7 +219,7 @@ pub fn rsh_tmp_bytes(n: usize, cols: usize) -> usize { n * cols * std::mem::size_of::() } -pub fn switch_degree(b: &mut T, a: &T) +pub fn switch_degree(b: &mut T, col_b: usize, a: &T, col_a: usize) where ::Scalar: IntegerType, { @@ -237,8 +237,8 @@ where (0..size).for_each(|i| { izip!( - a.at_limb(i).iter().step_by(gap_in), - b.at_limb_mut(i).iter_mut().step_by(gap_out) + a.at(col_a, i).iter().step_by(gap_in), + b.at_mut(col_b, i).iter_mut().step_by(gap_out) ) .for_each(|(x_in, x_out)| *x_out = *x_in); }); diff --git a/base2k/src/internals.rs b/base2k/src/internals.rs index d7b08dc..f2fbe3b 100644 --- a/base2k/src/internals.rs +++ b/base2k/src/internals.rs @@ -2,102 +2,6 @@ use std::cmp::{max, min}; use crate::{Backend, IntegerType, Module, ZnxBasics, ZnxLayout, ffi::module::MODULE}; -pub(crate) fn znx_post_process_ternary_op(c: &mut C, a: &A, b: &B) -where - C: ZnxBasics + ZnxLayout, - A: ZnxBasics + ZnxLayout, - B: ZnxBasics + ZnxLayout, - C::Scalar: IntegerType, -{ - #[cfg(debug_assertions)] - { - assert_ne!(a.as_ptr(), b.as_ptr()); - assert_ne!(b.as_ptr(), c.as_ptr()); - assert_ne!(a.as_ptr(), c.as_ptr()); - } - - let a_cols: usize = a.cols(); - let b_cols: usize = b.cols(); - let c_cols: usize = c.cols(); - - let min_ab_cols: usize = min(a_cols, b_cols); - let max_ab_cols: usize = max(a_cols, b_cols); - - // Copies shared shared cols between (c, max(a, b)) - if a_cols != b_cols { - if a_cols > b_cols { - let min_size = min(c.size(), a.size()); - (min_ab_cols..min(max_ab_cols, c_cols)).for_each(|i| { - (0..min_size).for_each(|j| { - c.at_poly_mut(i, j).copy_from_slice(a.at_poly(i, j)); - if NEGATE { - c.at_poly_mut(i, j).iter_mut().for_each(|x| *x = -*x); - } - }); - (min_size..c.size()).for_each(|j| { - c.zero_at(i, j); - }); - }); - } else { - let min_size = min(c.size(), b.size()); - (min_ab_cols..min(max_ab_cols, c_cols)).for_each(|i| { - (0..min_size).for_each(|j| { - c.at_poly_mut(i, j).copy_from_slice(b.at_poly(i, j)); - if NEGATE { - c.at_poly_mut(i, j).iter_mut().for_each(|x| *x = -*x); - } - }); - (min_size..c.size()).for_each(|j| { - c.zero_at(i, j); - }); - }); - } - } - - // Zeroes the cols of c > max(a, b). - if c_cols > max_ab_cols { - (max_ab_cols..c_cols).for_each(|i| { - (0..c.size()).for_each(|j| { - c.zero_at(i, j); - }) - }); - } -} - -#[inline(always)] -pub fn apply_binary_op( - module: &Module, - c: &mut C, - a: &A, - b: &B, - op: impl Fn(&mut [C::Scalar], &[A::Scalar], &[B::Scalar]), -) where - BE: Backend, - C: ZnxBasics + ZnxLayout, - A: ZnxBasics + ZnxLayout, - B: ZnxBasics + ZnxLayout, - C::Scalar: IntegerType, -{ - #[cfg(debug_assertions)] - { - assert_eq!(a.n(), module.n()); - assert_eq!(b.n(), module.n()); - assert_eq!(c.n(), module.n()); - assert_ne!(a.as_ptr(), b.as_ptr()); - } - let a_cols: usize = a.cols(); - let b_cols: usize = b.cols(); - let c_cols: usize = c.cols(); - let min_ab_cols: usize = min(a_cols, b_cols); - let min_cols: usize = min(c_cols, min_ab_cols); - // Applies over shared cols between (a, b, c) - (0..min_cols).for_each(|i| op(c.at_poly_mut(i, 0), a.at_poly(i, 0), b.at_poly(i, 0))); - // Copies/Negates/Zeroes the remaining cols if op is not inplace. - if c.as_ptr() != a.as_ptr() && c.as_ptr() != b.as_ptr() { - znx_post_process_ternary_op::(c, a, b); - } -} - #[inline(always)] pub fn apply_unary_op( module: &Module, diff --git a/base2k/src/scalar_znx_dft.rs b/base2k/src/scalar_znx_dft.rs index cfe2f45..474135b 100644 --- a/base2k/src/scalar_znx_dft.rs +++ b/base2k/src/scalar_znx_dft.rs @@ -230,7 +230,7 @@ pub trait ScalarZnxDftOps { /// Applies the [SvpPPol] x [VecZnxDft] product, where each limb of /// the [VecZnxDft] is multiplied with [SvpPPol]. - fn svp_apply_dft(&self, c: &mut VecZnxDft, a: &ScalarZnxDft, b: &VecZnx); + fn svp_apply_dft(&self, c: &mut VecZnxDft, a: &ScalarZnxDft, b: &VecZnx, b_col: usize); } impl ScalarZnxDftOps for Module { @@ -261,16 +261,16 @@ impl ScalarZnxDftOps for Module { unsafe { svp::svp_prepare(self.ptr, svp_ppol.ptr as *mut svp_ppol_t, a.as_ptr()) } } - fn svp_apply_dft(&self, c: &mut VecZnxDft, a: &ScalarZnxDft, b: &VecZnx) { + fn svp_apply_dft(&self, c: &mut VecZnxDft, a: &ScalarZnxDft, b: &VecZnx, b_col: usize) { unsafe { svp::svp_apply_dft( self.ptr, c.ptr as *mut vec_znx_dft_t, - c.cols() as u64, + c.size() as u64, a.ptr as *const svp_ppol_t, - b.as_ptr(), - b.cols() as u64, - b.n() as u64, + b.at_ptr(b_col, 0), + b.size() as u64, + b.sl() as u64, ) } } diff --git a/base2k/src/vec_znx.rs b/base2k/src/vec_znx.rs index 1bb8ab3..53aeb39 100644 --- a/base2k/src/vec_znx.rs +++ b/base2k/src/vec_znx.rs @@ -193,8 +193,8 @@ impl VecZnx { normalize(log_base2k, self, carry) } - pub fn switch_degree(&self, a: &mut Self) { - switch_degree(a, self) + pub fn switch_degree(&self, col: usize, a: &mut Self, col_a: usize) { + switch_degree(a, col_a, self, col) } // Prints the first `n` coefficients of each limb diff --git a/base2k/src/vec_znx_big_ops.rs b/base2k/src/vec_znx_big_ops.rs index 4b4e54e..c87c95d 100644 --- a/base2k/src/vec_znx_big_ops.rs +++ b/base2k/src/vec_znx_big_ops.rs @@ -232,9 +232,9 @@ impl VecZnxBigOps for Module { } let a_size: usize = a.size(); - let b_size: usize = b.sl(); - let a_sl: usize = a.size(); - let b_sl: usize = a.sl(); + let b_size: usize = b.size(); + let a_sl: usize = a.sl(); + let b_sl: usize = b.sl(); let a_cols: usize = a.cols(); let b_cols: usize = b.cols(); let min_cols: usize = min(a_cols, b_cols); diff --git a/base2k/src/vec_znx_ops.rs b/base2k/src/vec_znx_ops.rs index c7c8d85..9f2d43a 100644 --- a/base2k/src/vec_znx_ops.rs +++ b/base2k/src/vec_znx_ops.rs @@ -1,9 +1,5 @@ -use std::cmp::min; - -use crate::ffi::module::MODULE; use crate::ffi::vec_znx; -use crate::internals::{apply_binary_op, apply_unary_op, ffi_binary_op_factory_type_0, ffi_binary_op_factory_type_1}; -use crate::{Backend, Module, VecZnx, ZnxBase, ZnxBasics, ZnxInfos, ZnxLayout, assert_alignement, switch_degree}; +use crate::{Backend, Module, VecZnx, ZnxBase, ZnxInfos, ZnxLayout, assert_alignement, switch_degree}; pub trait VecZnxOps { /// Allocates a new [VecZnx]. /// @@ -43,62 +39,70 @@ pub trait VecZnxOps { fn bytes_of_vec_znx(&self, cols: usize, size: usize) -> usize; /// Returns the minimum number of bytes necessary for normalization. - fn vec_znx_normalize_tmp_bytes(&self, cols: usize) -> usize; + fn vec_znx_normalize_tmp_bytes(&self) -> usize; - /// Normalizes `a` and stores the result into `b`. - fn vec_znx_normalize(&self, log_base2k: usize, b: &mut VecZnx, a: &VecZnx, tmp_bytes: &mut [u8]); + /// Normalizes the selected column of `a` and stores the result into the selected column of `res`. + fn vec_znx_normalize( + &self, + log_base2k: usize, + res: &mut VecZnx, + col_res: usize, + a: &VecZnx, + col_a: usize, + tmp_bytes: &mut [u8], + ); - /// Normalizes `a` and stores the result into `a`. - fn vec_znx_normalize_inplace(&self, log_base2k: usize, a: &mut VecZnx, tmp_bytes: &mut [u8]); + /// Normalizes the selected column of `a`. + fn vec_znx_normalize_inplace(&self, log_base2k: usize, a: &mut VecZnx, col_a: usize, tmp_bytes: &mut [u8]); - /// Adds `a` to `b` and write the result on `c`. - fn vec_znx_add(&self, c: &mut VecZnx, a: &VecZnx, b: &VecZnx); + /// Adds the selected column of `a` to the selected column of `b` and write the result on the selected column of `c`. + fn vec_znx_add(&self, res: &mut VecZnx, col_res: usize, a: &VecZnx, col_a: usize, b: &VecZnx, col_b: usize); - /// Adds `a` to `b` and write the result on `b`. - fn vec_znx_add_inplace(&self, b: &mut VecZnx, a: &VecZnx); + /// Adds the selected column of `a` to the selected column of `b` and write the result on the selected column of `res`. + fn vec_znx_add_inplace(&self, res: &mut VecZnx, col_res: usize, a: &VecZnx, col_a: usize); - /// Subtracts `b` to `a` and write the result on `c`. - fn vec_znx_sub(&self, c: &mut VecZnx, a: &VecZnx, b: &VecZnx); + /// Subtracts the selected column of `b` to the selected column of `a` and write the result on the selected column of `res`. + fn vec_znx_sub(&self, res: &mut VecZnx, col_res: usize, a: &VecZnx, col_a: usize, b: &VecZnx, col_b: usize); - /// Subtracts `a` to `b` and write the result on `b`. - fn vec_znx_sub_ab_inplace(&self, b: &mut VecZnx, a: &VecZnx); + /// Subtracts the selected column of `a` to the selected column of `res`. + fn vec_znx_sub_ab_inplace(&self, res: &mut VecZnx, col_res: usize, a: &VecZnx, col_a: usize); - /// Subtracts `b` to `a` and write the result on `b`. - fn vec_znx_sub_ba_inplace(&self, b: &mut VecZnx, a: &VecZnx); + /// Subtracts the selected column of `a` to the selected column of `res` and negates the selected column of `res`. + fn vec_znx_sub_ba_inplace(&self, res: &mut VecZnx, col_res: usize, a: &VecZnx, col_a: usize); - // Negates `a` and stores the result on `b`. - fn vec_znx_negate(&self, b: &mut VecZnx, a: &VecZnx); + // Negates the selected column of `a` and stores the result on the selected column of `res`. + fn vec_znx_negate(&self, res: &mut VecZnx, col_res: usize, a: &VecZnx, col_a: usize); - /// Negages `a` and stores the result on `a`. - fn vec_znx_negate_inplace(&self, a: &mut VecZnx); + /// Negates the selected column of `a`. + fn vec_znx_negate_inplace(&self, a: &mut VecZnx, col_a: usize); - /// Multiplies `a` by X^k and stores the result on `b`. - fn vec_znx_rotate(&self, k: i64, b: &mut VecZnx, a: &VecZnx); + /// Multiplies the selected column of `a` by X^k and stores the result on the selected column of `res`. + fn vec_znx_rotate(&self, k: i64, res: &mut VecZnx, col_res: usize, a: &VecZnx, col_a: usize); - /// Multiplies `a` by X^k and stores the result on `a`. - fn vec_znx_rotate_inplace(&self, k: i64, a: &mut VecZnx); + /// Multiplies the selected column of `a` by X^k. + fn vec_znx_rotate_inplace(&self, k: i64, a: &mut VecZnx, col_a: usize); - /// Applies the automorphism X^i -> X^ik on `a` and stores the result on `b`. - fn vec_znx_automorphism(&self, k: i64, b: &mut VecZnx, a: &VecZnx); + /// Applies the automorphism X^i -> X^ik on the selected column of `a` and stores the result on the selected column of `res`. + fn vec_znx_automorphism(&self, k: i64, res: &mut VecZnx, col_res: usize, a: &VecZnx, col_a: usize); - /// Applies the automorphism X^i -> X^ik on `a` and stores the result on `a`. - fn vec_znx_automorphism_inplace(&self, k: i64, a: &mut VecZnx); + /// Applies the automorphism X^i -> X^ik on the selected column of `a`. + fn vec_znx_automorphism_inplace(&self, k: i64, a: &mut VecZnx, col_a: usize); - /// Splits b into subrings and copies them them into a. + /// Splits the selected columns of `b` into subrings and copies them them into the selected column of `res`. /// /// # Panics /// /// This method requires that all [VecZnx] of b have the same ring degree /// and that b.n() * b.len() <= a.n() - fn vec_znx_split(&self, b: &mut Vec, a: &VecZnx, buf: &mut VecZnx); + fn vec_znx_split(&self, res: &mut Vec, col_res: usize, a: &VecZnx, col_a: usize, buf: &mut VecZnx); - /// Merges the subrings a into b. + /// Merges the subrings of the selected column of `a` into the selected column of `res`. /// /// # Panics /// /// This method requires that all [VecZnx] of a have the same ring degree /// and that a.n() * a.len() <= b.n() - fn vec_znx_merge(&self, b: &mut VecZnx, a: &Vec); + fn vec_znx_merge(&self, res: &mut VecZnx, col_res: usize, a: &Vec, col_a: usize); } impl VecZnxOps for Module { @@ -118,164 +122,213 @@ impl VecZnxOps for Module { VecZnx::from_bytes_borrow(self, cols, size, tmp_bytes) } - fn vec_znx_normalize_tmp_bytes(&self, cols: usize) -> usize { - unsafe { vec_znx::vec_znx_normalize_base2k_tmp_bytes(self.ptr) as usize * cols } + fn vec_znx_normalize_tmp_bytes(&self) -> usize { + unsafe { vec_znx::vec_znx_normalize_base2k_tmp_bytes(self.ptr) as usize } } - fn vec_znx_normalize(&self, log_base2k: usize, b: &mut VecZnx, a: &VecZnx, tmp_bytes: &mut [u8]) { + fn vec_znx_normalize( + &self, + log_base2k: usize, + res: &mut VecZnx, + col_res: usize, + a: &VecZnx, + col_a: usize, + tmp_bytes: &mut [u8], + ) { #[cfg(debug_assertions)] { - assert!(tmp_bytes.len() >= Self::vec_znx_normalize_tmp_bytes(&self, a.cols())); + assert_eq!(a.n(), self.n()); + assert_eq!(res.n(), self.n()); + assert!(tmp_bytes.len() >= Self::vec_znx_normalize_tmp_bytes(&self)); assert_alignement(tmp_bytes.as_ptr()); } - - let a_size: usize = a.size(); - let b_size: usize = b.sl(); - let a_sl: usize = a.size(); - let b_sl: usize = a.sl(); - let a_cols: usize = a.cols(); - let b_cols: usize = b.cols(); - let min_cols: usize = min(a_cols, b_cols); - (0..min_cols).for_each(|i| unsafe { + unsafe { vec_znx::vec_znx_normalize_base2k( self.ptr, log_base2k as u64, - b.at_mut_ptr(i, 0), - b_size as u64, - b_sl as u64, - a.at_ptr(i, 0), - a_size as u64, - a_sl as u64, + res.at_mut_ptr(col_res, 0), + res.size() as u64, + res.sl() as u64, + a.at_ptr(col_a, 0), + a.size() as u64, + a.sl() as u64, tmp_bytes.as_mut_ptr(), ); - }); - - (min_cols..b_cols).for_each(|i| (0..b_size).for_each(|j| b.zero_at(i, j))); + } } - fn vec_znx_normalize_inplace(&self, log_base2k: usize, a: &mut VecZnx, tmp_bytes: &mut [u8]) { + fn vec_znx_normalize_inplace(&self, log_base2k: usize, a: &mut VecZnx, col_a: usize, tmp_bytes: &mut [u8]) { unsafe { let a_ptr: *mut VecZnx = a as *mut VecZnx; - Self::vec_znx_normalize(self, log_base2k, &mut *a_ptr, &*a_ptr, tmp_bytes); + Self::vec_znx_normalize( + self, + log_base2k, + &mut *a_ptr, + col_a, + &*a_ptr, + col_a, + tmp_bytes, + ); } } - fn vec_znx_add(&self, c: &mut VecZnx, a: &VecZnx, b: &VecZnx) { - let op = ffi_ternary_op_factory( - self.ptr, - c.size(), - c.sl(), - a.size(), - a.sl(), - b.size(), - b.sl(), - vec_znx::vec_znx_add, - ); - apply_binary_op::(self, c, a, b, op); - } - - fn vec_znx_add_inplace(&self, b: &mut VecZnx, a: &VecZnx) { + fn vec_znx_add(&self, res: &mut VecZnx, col_res: usize, a: &VecZnx, col_a: usize, b: &VecZnx, col_b: usize) { + #[cfg(debug_assertions)] + { + assert_eq!(a.n(), self.n()); + assert_eq!(b.n(), self.n()); + assert_eq!(res.n(), self.n()); + assert_ne!(a.as_ptr(), b.as_ptr()); + } unsafe { - let b_ptr: *mut VecZnx = b as *mut VecZnx; - Self::vec_znx_add(self, &mut *b_ptr, a, &*b_ptr); + vec_znx::vec_znx_add( + self.ptr, + res.at_mut_ptr(col_res, 0), + res.size() as u64, + res.sl() as u64, + a.at_ptr(col_a, 0), + a.size() as u64, + a.sl() as u64, + b.at_ptr(col_b, 0), + b.size() as u64, + b.sl() as u64, + ) } } - fn vec_znx_sub(&self, c: &mut VecZnx, a: &VecZnx, b: &VecZnx) { - let op = ffi_ternary_op_factory( - self.ptr, - c.size(), - c.sl(), - a.size(), - a.sl(), - b.size(), - b.sl(), - vec_znx::vec_znx_sub, - ); - apply_binary_op::(self, c, a, b, op); - } - - fn vec_znx_sub_ab_inplace(&self, b: &mut VecZnx, a: &VecZnx) { + fn vec_znx_add_inplace(&self, res: &mut VecZnx, col_res: usize, a: &VecZnx, col_a: usize) { unsafe { - let b_ptr: *mut VecZnx = b as *mut VecZnx; - Self::vec_znx_sub(self, &mut *b_ptr, a, &*b_ptr); + let res_ptr: *mut VecZnx = res as *mut VecZnx; + Self::vec_znx_add(self, &mut *res_ptr, col_res, a, col_a, &*res_ptr, col_res); } } - fn vec_znx_sub_ba_inplace(&self, b: &mut VecZnx, a: &VecZnx) { + fn vec_znx_sub(&self, res: &mut VecZnx, col_res: usize, a: &VecZnx, col_a: usize, b: &VecZnx, col_b: usize) { + #[cfg(debug_assertions)] + { + assert_eq!(a.n(), self.n()); + assert_eq!(b.n(), self.n()); + assert_eq!(res.n(), self.n()); + assert_ne!(a.as_ptr(), b.as_ptr()); + } unsafe { - let b_ptr: *mut VecZnx = b as *mut VecZnx; - Self::vec_znx_sub(self, &mut *b_ptr, &*b_ptr, a); + vec_znx::vec_znx_sub( + self.ptr, + res.at_mut_ptr(col_res, 0), + res.size() as u64, + res.sl() as u64, + a.at_ptr(col_a, 0), + a.size() as u64, + a.sl() as u64, + b.at_ptr(col_b, 0), + b.size() as u64, + b.sl() as u64, + ) } } - fn vec_znx_negate(&self, b: &mut VecZnx, a: &VecZnx) { - let op = ffi_binary_op_factory_type_0( - self.ptr, - b.size(), - b.sl(), - a.size(), - a.sl(), - vec_znx::vec_znx_negate, - ); - apply_unary_op::(self, b, a, op); + fn vec_znx_sub_ab_inplace(&self, res: &mut VecZnx, col_res: usize, a: &VecZnx, col_a: usize) { + unsafe { + let res_ptr: *mut VecZnx = res as *mut VecZnx; + Self::vec_znx_sub(self, &mut *res_ptr, col_res, a, col_a, &*res_ptr, col_res); + } } - fn vec_znx_negate_inplace(&self, a: &mut VecZnx) { + fn vec_znx_sub_ba_inplace(&self, res: &mut VecZnx, col_res: usize, a: &VecZnx, col_a: usize) { + unsafe { + let res_ptr: *mut VecZnx = res as *mut VecZnx; + Self::vec_znx_sub(self, &mut *res_ptr, col_res, &*res_ptr, col_res, a, col_a); + } + } + + fn vec_znx_negate(&self, res: &mut VecZnx, col_res: usize, a: &VecZnx, col_a: usize) { + #[cfg(debug_assertions)] + { + assert_eq!(a.n(), self.n()); + assert_eq!(res.n(), self.n()); + } + unsafe { + vec_znx::vec_znx_negate( + self.ptr, + res.at_mut_ptr(col_res, 0), + res.size() as u64, + res.sl() as u64, + a.at_ptr(col_a, 0), + a.size() as u64, + a.sl() as u64, + ) + } + } + + fn vec_znx_negate_inplace(&self, a: &mut VecZnx, col_a: usize) { unsafe { let a_ptr: *mut VecZnx = a as *mut VecZnx; - Self::vec_znx_negate(self, &mut *a_ptr, &*a_ptr); + Self::vec_znx_negate(self, &mut *a_ptr, col_a, &*a_ptr, col_a); } } - fn vec_znx_rotate(&self, k: i64, b: &mut VecZnx, a: &VecZnx) { - let op = ffi_binary_op_factory_type_1( - self.ptr, - k, - b.size(), - b.sl(), - a.size(), - a.sl(), - vec_znx::vec_znx_rotate, - ); - apply_unary_op::(self, b, a, op); + fn vec_znx_rotate(&self, k: i64, res: &mut VecZnx, col_res: usize, a: &VecZnx, col_a: usize) { + #[cfg(debug_assertions)] + { + assert_eq!(a.n(), self.n()); + assert_eq!(res.n(), self.n()); + } + unsafe { + vec_znx::vec_znx_rotate( + self.ptr, + k, + res.at_mut_ptr(col_res, 0), + res.size() as u64, + res.sl() as u64, + a.at_ptr(col_a, 0), + a.size() as u64, + a.sl() as u64, + ) + } } - fn vec_znx_rotate_inplace(&self, k: i64, a: &mut VecZnx) { + fn vec_znx_rotate_inplace(&self, k: i64, a: &mut VecZnx, col_a: usize) { unsafe { let a_ptr: *mut VecZnx = a as *mut VecZnx; - Self::vec_znx_rotate(self, k, &mut *a_ptr, &*a_ptr); + Self::vec_znx_rotate(self, k, &mut *a_ptr, col_a, &*a_ptr, col_a); } } - fn vec_znx_automorphism(&self, k: i64, b: &mut VecZnx, a: &VecZnx) { - let op = ffi_binary_op_factory_type_1( - self.ptr, - k, - b.size(), - b.sl(), - a.size(), - a.sl(), - vec_znx::vec_znx_automorphism, - ); - apply_unary_op::(self, b, a, op); + fn vec_znx_automorphism(&self, k: i64, res: &mut VecZnx, col_res: usize, a: &VecZnx, col_a: usize) { + #[cfg(debug_assertions)] + { + assert_eq!(a.n(), self.n()); + assert_eq!(res.n(), self.n()); + } + unsafe { + vec_znx::vec_znx_automorphism( + self.ptr, + k, + res.at_mut_ptr(col_res, 0), + res.size() as u64, + res.sl() as u64, + a.at_ptr(col_a, 0), + a.size() as u64, + a.sl() as u64, + ) + } } - fn vec_znx_automorphism_inplace(&self, k: i64, a: &mut VecZnx) { + fn vec_znx_automorphism_inplace(&self, k: i64, a: &mut VecZnx, col_a: usize) { unsafe { let a_ptr: *mut VecZnx = a as *mut VecZnx; - Self::vec_znx_automorphism(self, k, &mut *a_ptr, &*a_ptr); + Self::vec_znx_automorphism(self, k, &mut *a_ptr, col_a, &*a_ptr, col_a); } } - fn vec_znx_split(&self, b: &mut Vec, a: &VecZnx, buf: &mut VecZnx) { - let (n_in, n_out) = (a.n(), b[0].n()); + fn vec_znx_split(&self, res: &mut Vec, col_res: usize, a: &VecZnx, col_a: usize, buf: &mut VecZnx) { + let (n_in, n_out) = (a.n(), res[0].n()); debug_assert!( n_out < n_in, "invalid a: output ring degree should be smaller" ); - b[1..].iter().for_each(|bi| { + res[1..].iter().for_each(|bi| { debug_assert_eq!( bi.n(), n_out, @@ -283,19 +336,19 @@ impl VecZnxOps for Module { ) }); - b.iter_mut().enumerate().for_each(|(i, bi)| { + res.iter_mut().enumerate().for_each(|(i, bi)| { if i == 0 { - switch_degree(bi, a); - self.vec_znx_rotate(-1, buf, a); + switch_degree(bi, col_res, a, col_a); + self.vec_znx_rotate(-1, buf, 0, a, col_a); } else { - switch_degree(bi, buf); - self.vec_znx_rotate_inplace(-1, buf); + switch_degree(bi, col_res, buf, col_a); + self.vec_znx_rotate_inplace(-1, buf, col_a); } }) } - fn vec_znx_merge(&self, b: &mut VecZnx, a: &Vec) { - let (n_in, n_out) = (b.n(), a[0].n()); + fn vec_znx_merge(&self, res: &mut VecZnx, col_res: usize, a: &Vec, col_a: usize) { + let (n_in, n_out) = (res.n(), a[0].n()); debug_assert!( n_out < n_in, @@ -310,456 +363,10 @@ impl VecZnxOps for Module { }); a.iter().enumerate().for_each(|(_, ai)| { - switch_degree(b, ai); - self.vec_znx_rotate_inplace(-1, b); + switch_degree(res, col_res, ai, col_a); + self.vec_znx_rotate_inplace(-1, res, col_res); }); - self.vec_znx_rotate_inplace(a.len() as i64, b); - } -} - -fn ffi_ternary_op_factory( - module_ptr: *const MODULE, - c_size: usize, - c_sl: usize, - a_size: usize, - a_sl: usize, - b_size: usize, - b_sl: usize, - op_fn: unsafe extern "C" fn(*const MODULE, *mut i64, u64, u64, *const i64, u64, u64, *const i64, u64, u64), -) -> impl Fn(&mut [i64], &[i64], &[i64]) { - move |cv: &mut [i64], av: &[i64], bv: &[i64]| unsafe { - op_fn( - module_ptr, - cv.as_mut_ptr(), - c_size as u64, - c_sl as u64, - av.as_ptr(), - a_size as u64, - a_sl as u64, - bv.as_ptr(), - b_size as u64, - b_sl as u64, - ) - } -} - -#[cfg(test)] -mod tests { - use crate::internals::znx_post_process_ternary_op; - use crate::{Backend, FFT64, Module, Sampling, VecZnx, VecZnxOps, ZnxBasics, ZnxInfos, ZnxLayout, ffi::vec_znx}; - - use itertools::izip; - use sampling::source::Source; - use std::cmp::min; - - #[test] - fn vec_znx_add() { - let n: usize = 8; - let module: Module = Module::::new(n); - let op = |cv: &mut [i64], av: &[i64], bv: &[i64]| { - izip!(cv.iter_mut(), bv.iter(), av.iter()).for_each(|(ci, bi, ai)| *ci = *bi + *ai); - }; - test_binary_op::( - &module, - &|c: &mut VecZnx, a: &VecZnx, b: &VecZnx| module.vec_znx_add(c, a, b), - op, - ); - } - - #[test] - fn vec_znx_add_inplace() { - let n: usize = 8; - let module: Module = Module::::new(n); - let op = |bv: &mut [i64], av: &[i64]| { - izip!(bv.iter_mut(), av.iter()).for_each(|(bi, ai)| *bi = *bi + *ai); - }; - test_binary_op_inplace::( - &module, - &|b: &mut VecZnx, a: &VecZnx| module.vec_znx_add_inplace(b, a), - op, - ); - } - - #[test] - fn vec_znx_sub() { - let n: usize = 8; - let module: Module = Module::::new(n); - let op = |cv: &mut [i64], av: &[i64], bv: &[i64]| { - izip!(cv.iter_mut(), bv.iter(), av.iter()).for_each(|(ci, bi, ai)| *ci = *bi - *ai); - }; - test_binary_op::( - &module, - &|c: &mut VecZnx, a: &VecZnx, b: &VecZnx| module.vec_znx_sub(c, a, b), - op, - ); - } - - #[test] - fn vec_znx_sub_ab_inplace() { - let n: usize = 8; - let module: Module = Module::::new(n); - let op = |bv: &mut [i64], av: &[i64]| { - izip!(bv.iter_mut(), av.iter()).for_each(|(bi, ai)| *bi = *ai - *bi); - }; - test_binary_op_inplace::( - &module, - &|b: &mut VecZnx, a: &VecZnx| module.vec_znx_sub_ab_inplace(b, a), - op, - ); - } - - #[test] - fn vec_znx_sub_ba_inplace() { - let n: usize = 8; - let module: Module = Module::::new(n); - let op = |bv: &mut [i64], av: &[i64]| { - izip!(bv.iter_mut(), av.iter()).for_each(|(bi, ai)| *bi = *bi - *ai); - }; - test_binary_op_inplace::( - &module, - &|b: &mut VecZnx, a: &VecZnx| module.vec_znx_sub_ba_inplace(b, a), - op, - ); - } - - #[test] - fn vec_znx_negate() { - let n: usize = 8; - let module: Module = Module::::new(n); - let op = |b: &mut [i64], a: &[i64]| { - izip!(b.iter_mut(), a.iter()).for_each(|(bi, ai)| *bi = -*ai); - }; - test_unary_op( - &module, - |b: &mut VecZnx, a: &VecZnx| module.vec_znx_negate(b, a), - op, - ) - } - - #[test] - fn vec_znx_negate_inplace() { - let n: usize = 8; - let module: Module = Module::::new(n); - let op = |a: &mut [i64]| a.iter_mut().for_each(|xi| *xi = -*xi); - test_unary_op_inplace( - &module, - |a: &mut VecZnx| module.vec_znx_negate_inplace(a), - op, - ) - } - - #[test] - fn vec_znx_rotate() { - let n: usize = 8; - let module: Module = Module::::new(n); - let k: i64 = 53; - let op = |b: &mut [i64], a: &[i64]| { - assert_eq!(b.len(), a.len()); - b.copy_from_slice(a); - - let mut k_mod2n: i64 = k % (2 * n as i64); - if k_mod2n < 0 { - k_mod2n += 2 * n as i64; - } - let sign: i64 = (k_mod2n.abs() / (n as i64)) & 1; - let k_modn: i64 = k_mod2n % (n as i64); - - b.rotate_right(k_modn as usize); - b[0..k_modn as usize].iter_mut().for_each(|x| *x = -*x); - - if sign == 1 { - b.iter_mut().for_each(|x| *x = -*x); - } - }; - test_unary_op( - &module, - |b: &mut VecZnx, a: &VecZnx| module.vec_znx_rotate(k, b, a), - op, - ) - } - - #[test] - fn vec_znx_rotate_inplace() { - let n: usize = 8; - let module: Module = Module::::new(n); - let k: i64 = 53; - let rot = |a: &mut [i64]| { - let mut k_mod2n: i64 = k % (2 * n as i64); - if k_mod2n < 0 { - k_mod2n += 2 * n as i64; - } - let sign: i64 = (k_mod2n.abs() / (n as i64)) & 1; - let k_modn: i64 = k_mod2n % (n as i64); - - a.rotate_right(k_modn as usize); - a[0..k_modn as usize].iter_mut().for_each(|x| *x = -*x); - - if sign == 1 { - a.iter_mut().for_each(|x| *x = -*x); - } - }; - test_unary_op_inplace( - &module, - |a: &mut VecZnx| module.vec_znx_rotate_inplace(k, a), - rot, - ) - } - - #[test] - fn vec_znx_automorphism() { - let n: usize = 8; - let module: Module = Module::::new(n); - let k: i64 = -5; - let op = |b: &mut [i64], a: &[i64]| { - assert_eq!(b.len(), a.len()); - unsafe { - vec_znx::vec_znx_automorphism( - module.ptr, - k, - b.as_mut_ptr(), - 1u64, - n as u64, - a.as_ptr(), - 1u64, - n as u64, - ); - } - }; - test_unary_op( - &module, - |b: &mut VecZnx, a: &VecZnx| module.vec_znx_automorphism(k, b, a), - op, - ) - } - - #[test] - fn vec_znx_automorphism_inplace() { - let n: usize = 8; - let module: Module = Module::::new(n); - let k: i64 = -5; - let op = |a: &mut [i64]| unsafe { - vec_znx::vec_znx_automorphism( - module.ptr, - k, - a.as_mut_ptr(), - 1u64, - n as u64, - a.as_ptr(), - 1u64, - n as u64, - ); - }; - test_unary_op_inplace( - &module, - |a: &mut VecZnx| module.vec_znx_automorphism_inplace(k, a), - op, - ) - } - - fn test_binary_op( - module: &Module, - func_have: impl Fn(&mut VecZnx, &VecZnx, &VecZnx), - func_want: impl Fn(&mut [i64], &[i64], &[i64]), - ) { - let a_size: usize = 3; - let b_size: usize = 4; - let c_size: usize = 5; - let mut source: Source = Source::new([0u8; 32]); - - [1usize, 2, 3].iter().for_each(|a_cols| { - [1usize, 2, 3].iter().for_each(|b_cols| { - [1usize, 2, 3].iter().for_each(|c_cols| { - let min_ab_cols: usize = min(*a_cols, *b_cols); - let min_cols: usize = min(*c_cols, min_ab_cols); - let min_size: usize = min(c_size, min(a_size, b_size)); - - // Allocats a and populates with random values. - let mut a: VecZnx = module.new_vec_znx(*a_cols, a_size); - (0..*a_cols).for_each(|i| { - module.fill_uniform(3, &mut a, i, a_size, &mut source); - }); - - // Allocats b and populates with random values. - let mut b: VecZnx = module.new_vec_znx(*b_cols, b_size); - (0..*b_cols).for_each(|i| { - module.fill_uniform(3, &mut b, i, b_size, &mut source); - }); - - // Allocats c and populates with random values. - let mut c_have: VecZnx = module.new_vec_znx(*c_cols, c_size); - (0..c_have.cols()).for_each(|i| { - module.fill_uniform(3, &mut c_have, i, c_size, &mut source); - }); - - // Applies the function to test - func_have(&mut c_have, &a, &b); - - let mut c_want: VecZnx = module.new_vec_znx(*c_cols, c_size); - - // Applies the reference function and expected behavior. - // Adds with the minimum matching columns - (0..min_cols).for_each(|i| { - // Adds with th eminimum matching size - (0..min_size).for_each(|j| { - func_want(c_want.at_poly_mut(i, j), b.at_poly(i, j), a.at_poly(i, j)); - }); - - if a_size > b_size { - // Copies remaining size of lh if lh.size() > rh.size() - (min_size..a_size).for_each(|j| { - izip!(c_want.at_poly_mut(i, j).iter_mut(), a.at_poly(i, j).iter()).for_each(|(ci, ai)| *ci = *ai); - if NEGATE { - c_want.at_poly_mut(i, j).iter_mut().for_each(|x| *x = -*x); - } - }); - } else { - // Copies the remaining size of rh if the are greater - (min_size..b_size).for_each(|j| { - izip!(c_want.at_poly_mut(i, j).iter_mut(), b.at_poly(i, j).iter()).for_each(|(ci, bi)| *ci = *bi); - if NEGATE { - c_want.at_poly_mut(i, j).iter_mut().for_each(|x| *x = -*x); - } - }); - } - }); - - znx_post_process_ternary_op::(&mut c_want, &a, &b); - - assert_eq!(c_have.raw(), c_want.raw()); - }); - }); - }); - } - - fn test_binary_op_inplace( - module: &Module, - func_have: impl Fn(&mut VecZnx, &VecZnx), - func_want: impl Fn(&mut [i64], &[i64]), - ) { - let a_size: usize = 3; - let b_size: usize = 5; - let mut source = Source::new([0u8; 32]); - - [1usize, 2, 3].iter().for_each(|a_cols| { - [1usize, 2, 3].iter().for_each(|b_cols| { - let min_cols: usize = min(*b_cols, *a_cols); - let min_size: usize = min(b_size, a_size); - - // Allocats a and populates with random values. - let mut a: VecZnx = module.new_vec_znx(*a_cols, a_size); - (0..*a_cols).for_each(|i| { - module.fill_uniform(3, &mut a, i, a_size, &mut source); - }); - - // Allocats b and populates with random values. - let mut b_have: VecZnx = module.new_vec_znx(*b_cols, b_size); - (0..*b_cols).for_each(|i| { - module.fill_uniform(3, &mut b_have, i, b_size, &mut source); - }); - - let mut b_want: VecZnx = module.new_vec_znx(*b_cols, b_size); - b_want.raw_mut().copy_from_slice(b_have.raw()); - - // Applies the function to test. - func_have(&mut b_have, &a); - - // Applies the reference function and expected behavior. - // Applies with the minimum matching columns - (0..min_cols).for_each(|i| { - // Adds with th eminimum matching size - (0..min_size).for_each(|j| func_want(b_want.at_poly_mut(i, j), a.at_poly(i, j))); - if NEGATE { - (min_size..b_size).for_each(|j| { - b_want.at_poly_mut(i, j).iter_mut().for_each(|x| *x = -*x); - }); - } - }); - - assert_eq!(b_have.raw(), b_want.raw()); - }); - }); - } - - fn test_unary_op( - module: &Module, - func_have: impl Fn(&mut VecZnx, &VecZnx), - func_want: impl Fn(&mut [i64], &[i64]), - ) { - let a_size: usize = 3; - let b_size: usize = 5; - let mut source = Source::new([0u8; 32]); - - [1usize, 2, 3].iter().for_each(|a_cols| { - [1usize, 2, 3].iter().for_each(|b_cols| { - let min_cols: usize = min(*b_cols, *a_cols); - let min_size: usize = min(b_size, a_size); - - // Allocats a and populates with random values. - let mut a: VecZnx = module.new_vec_znx(*a_cols, a_size); - (0..a.cols()).for_each(|i| { - module.fill_uniform(3, &mut a, i, a_size, &mut source); - }); - - // Allocats b and populates with random values. - let mut b_have: VecZnx = module.new_vec_znx(*b_cols, b_size); - (0..b_have.cols()).for_each(|i| { - module.fill_uniform(3, &mut b_have, i, b_size, &mut source); - }); - - let mut b_want: VecZnx = module.new_vec_znx(*b_cols, b_size); - - // Applies the function to test. - func_have(&mut b_have, &a); - - // Applies the reference function and expected behavior. - // Applies on the minimum matching columns - (0..min_cols).for_each(|i| { - // Applies on the minimum matching size - (0..min_size).for_each(|j| func_want(b_want.at_poly_mut(i, j), a.at_poly(i, j))); - - // Zeroes the unmatching size - (min_size..b_size).for_each(|j| { - b_want.zero_at(i, j); - }) - }); - - // Zeroes the unmatching columns - (min_cols..*b_cols).for_each(|i| { - (0..b_size).for_each(|j| { - b_want.zero_at(i, j); - }) - }); - - assert_eq!(b_have.raw(), b_want.raw()); - }); - }); - } - - fn test_unary_op_inplace(module: &Module, func_have: impl Fn(&mut VecZnx), func_want: impl Fn(&mut [i64])) { - let a_size: usize = 3; - let mut source = Source::new([0u8; 32]); - [1usize, 2, 3].iter().for_each(|a_cols| { - let mut a_have: VecZnx = module.new_vec_znx(*a_cols, a_size); - (0..*a_cols).for_each(|i| { - module.fill_uniform(3, &mut a_have, i, a_size, &mut source); - }); - - // Allocats a and populates with random values. - let mut a_want: VecZnx = module.new_vec_znx(*a_cols, a_size); - a_have.raw_mut().copy_from_slice(a_want.raw()); - - // Applies the function to test. - func_have(&mut a_have); - - // Applies the reference function and expected behavior. - // Applies on the minimum matching columns - (0..*a_cols).for_each(|i| { - // Applies on the minimum matching size - (0..a_size).for_each(|j| func_want(a_want.at_poly_mut(i, j))); - }); - - assert_eq!(a_have.raw(), a_want.raw()); - }); + self.vec_znx_rotate_inplace(a.len() as i64, res, col_res); } }