From bd933c0e94ef83875703e23987560a14a7d73d15 Mon Sep 17 00:00:00 2001
From: Jean-Philippe Bossuat
Date: Tue, 29 Apr 2025 15:53:26 +0200
Subject: [PATCH] Added VecZnxBig ops

---
 base2k/examples/rlwe_encrypt.rs |   2 +-
 base2k/src/commons.rs           | 104 +---------
 base2k/src/internals.rs         | 192 ++++++++++++++++++
 base2k/src/lib.rs               |   3 +
 base2k/src/vec_znx_big.rs       | 269 +------------------------
 base2k/src/vec_znx_big_ops.rs   | 339 ++++++++++++++++++++++++++++++++
 base2k/src/vec_znx_ops.rs       |  61 +-----
 7 files changed, 549 insertions(+), 421 deletions(-)
 create mode 100644 base2k/src/internals.rs
 create mode 100644 base2k/src/vec_znx_big_ops.rs

diff --git a/base2k/examples/rlwe_encrypt.rs b/base2k/examples/rlwe_encrypt.rs
index 3661f0d..8a5d09f 100644
--- a/base2k/examples/rlwe_encrypt.rs
+++ b/base2k/examples/rlwe_encrypt.rs
@@ -59,7 +59,7 @@ fn main() {
     m.normalize(log_base2k, &mut carry);
 
     // buf_big <- m - buf_big
-    module.vec_znx_big_sub_small_a_inplace(&mut buf_big, &m);
+    module.vec_znx_big_sub_small_ab_inplace(&mut buf_big, &m);
 
     println!("{:?}", buf_big.raw());
 
diff --git a/base2k/src/commons.rs b/base2k/src/commons.rs
index cfae556..969897d 100644
--- a/base2k/src/commons.rs
+++ b/base2k/src/commons.rs
@@ -1,6 +1,6 @@
 use crate::{Backend, Module, assert_alignement, cast_mut};
 use itertools::izip;
-use std::cmp::{max, min};
+use std::cmp::min;
 
 pub trait ZnxInfos {
     /// Returns the ring degree of the polynomials.
@@ -243,105 +243,3 @@
             .for_each(|(x_in, x_out)| *x_out = *x_in);
     });
 }
-
-pub fn znx_post_process_ternary_op<T: ZnxBasics + ZnxLayout, const NEGATE: bool>(c: &mut T, a: &T, b: &T)
-where
-    <T as ZnxLayout>::Scalar: IntegerType,
-{
-    #[cfg(debug_assertions)]
-    {
-        assert_ne!(a.as_ptr(), b.as_ptr());
-        assert_ne!(b.as_ptr(), c.as_ptr());
-        assert_ne!(a.as_ptr(), c.as_ptr());
-    }
-
-    let a_cols: usize = a.cols();
-    let b_cols: usize = b.cols();
-    let c_cols: usize = c.cols();
-
-    let min_ab_cols: usize = min(a_cols, b_cols);
-    let max_ab_cols: usize = max(a_cols, b_cols);
-
-    // Copies shared shared cols between (c, max(a, b))
-    if a_cols != b_cols {
-        let mut x: &T = a;
-        if a_cols < b_cols {
-            x = b;
-        }
-
-        let min_size = min(c.size(), x.size());
-        (min_ab_cols..min(max_ab_cols, c_cols)).for_each(|i| {
-            (0..min_size).for_each(|j| {
-                c.at_poly_mut(i, j).copy_from_slice(x.at_poly(i, j));
-                if NEGATE {
-                    c.at_poly_mut(i, j).iter_mut().for_each(|x| *x = -*x);
-                }
-            });
-            (min_size..c.size()).for_each(|j| {
-                c.zero_at(i, j);
-            });
-        });
-    }
-
-    // Zeroes the cols of c > max(a, b).
-    if c_cols > max_ab_cols {
-        (max_ab_cols..c_cols).for_each(|i| {
-            (0..c.size()).for_each(|j| {
-                c.zero_at(i, j);
-            })
-        });
-    }
-}
-
-#[inline(always)]
-pub fn apply_binary_op<T: ZnxBasics + ZnxLayout, B: Backend, const NEGATE: bool>(
-    module: &Module<B>,
-    c: &mut T,
-    a: &T,
-    b: &T,
-    op: impl Fn(&mut [T::Scalar], &[T::Scalar], &[T::Scalar]),
-) where
-    <T as ZnxLayout>::Scalar: IntegerType,
-{
-    #[cfg(debug_assertions)]
-    {
-        assert_eq!(a.n(), module.n());
-        assert_eq!(b.n(), module.n());
-        assert_eq!(c.n(), module.n());
-        assert_ne!(a.as_ptr(), b.as_ptr());
-    }
-    let a_cols: usize = a.cols();
-    let b_cols: usize = b.cols();
-    let c_cols: usize = c.cols();
-    let min_ab_cols: usize = min(a_cols, b_cols);
-    let min_cols: usize = min(c_cols, min_ab_cols);
-    // Applies over shared cols between (a, b, c)
-    (0..min_cols).for_each(|i| op(c.at_poly_mut(i, 0), a.at_poly(i, 0), b.at_poly(i, 0)));
-    // Copies/Negates/Zeroes the remaining cols if op is not inplace.
-    if c.as_ptr() != a.as_ptr() && c.as_ptr() != b.as_ptr() {
-        znx_post_process_ternary_op::<T, NEGATE>(c, a, b);
-    }
-}
-
-#[inline(always)]
-pub fn apply_unary_op<T: ZnxBasics + ZnxLayout, B: Backend>(
-    module: &Module<B>,
-    b: &mut T,
-    a: &T,
-    op: impl Fn(&mut [T::Scalar], &[T::Scalar]),
-) where
-    <T as ZnxLayout>::Scalar: IntegerType,
-{
-    #[cfg(debug_assertions)]
-    {
-        assert_eq!(a.n(), module.n());
-        assert_eq!(b.n(), module.n());
-    }
-    let a_cols: usize = a.cols();
-    let b_cols: usize = b.cols();
-    let min_cols: usize = min(a_cols, b_cols);
-    // Applies over the shared cols between (a, b)
-    (0..min_cols).for_each(|i| op(b.at_poly_mut(i, 0), a.at_poly(i, 0)));
-    // Zeroes the remaining cols of b.
-    (min_cols..b_cols).for_each(|i| (0..b.size()).for_each(|j| b.zero_at(i, j)));
-}
diff --git a/base2k/src/internals.rs b/base2k/src/internals.rs
new file mode 100644
index 0000000..d7b08dc
--- /dev/null
+++ b/base2k/src/internals.rs
@@ -0,0 +1,192 @@
+use std::cmp::{max, min};
+
+use crate::{Backend, IntegerType, Module, ZnxBasics, ZnxLayout, ffi::module::MODULE};
+
+pub(crate) fn znx_post_process_ternary_op<C, A, B, const NEGATE: bool>(c: &mut C, a: &A, b: &B)
+where
+    C: ZnxBasics + ZnxLayout,
+    A: ZnxBasics + ZnxLayout,
+    B: ZnxBasics + ZnxLayout,
+    C::Scalar: IntegerType,
+{
+    #[cfg(debug_assertions)]
+    {
+        assert_ne!(a.as_ptr(), b.as_ptr());
+        assert_ne!(b.as_ptr(), c.as_ptr());
+        assert_ne!(a.as_ptr(), c.as_ptr());
+    }
+
+    let a_cols: usize = a.cols();
+    let b_cols: usize = b.cols();
+    let c_cols: usize = c.cols();
+
+    let min_ab_cols: usize = min(a_cols, b_cols);
+    let max_ab_cols: usize = max(a_cols, b_cols);
+
+    // Copies the shared cols between (c, max(a, b)).
+    if a_cols != b_cols {
+        if a_cols > b_cols {
+            let min_size = min(c.size(), a.size());
+            (min_ab_cols..min(max_ab_cols, c_cols)).for_each(|i| {
+                (0..min_size).for_each(|j| {
+                    c.at_poly_mut(i, j).copy_from_slice(a.at_poly(i, j));
+                    if NEGATE {
+                        c.at_poly_mut(i, j).iter_mut().for_each(|x| *x = -*x);
+                    }
+                });
+                (min_size..c.size()).for_each(|j| {
+                    c.zero_at(i, j);
+                });
+            });
+        } else {
+            let min_size = min(c.size(), b.size());
+            (min_ab_cols..min(max_ab_cols, c_cols)).for_each(|i| {
+                (0..min_size).for_each(|j| {
+                    c.at_poly_mut(i, j).copy_from_slice(b.at_poly(i, j));
+                    if NEGATE {
+                        c.at_poly_mut(i, j).iter_mut().for_each(|x| *x = -*x);
+                    }
+                });
+                (min_size..c.size()).for_each(|j| {
+                    c.zero_at(i, j);
+                });
+            });
+        }
+    }
+
+    // Zeroes the cols of c > max(a, b).
+    if c_cols > max_ab_cols {
+        (max_ab_cols..c_cols).for_each(|i| {
+            (0..c.size()).for_each(|j| {
+                c.zero_at(i, j);
+            })
+        });
+    }
+}
+
+#[inline(always)]
+pub fn apply_binary_op<BE, C, A, B, const NEGATE: bool>(
+    module: &Module<BE>,
+    c: &mut C,
+    a: &A,
+    b: &B,
+    op: impl Fn(&mut [C::Scalar], &[A::Scalar], &[B::Scalar]),
+) where
+    BE: Backend,
+    C: ZnxBasics + ZnxLayout,
+    A: ZnxBasics + ZnxLayout,
+    B: ZnxBasics + ZnxLayout,
+    C::Scalar: IntegerType,
+{
+    #[cfg(debug_assertions)]
+    {
+        assert_eq!(a.n(), module.n());
+        assert_eq!(b.n(), module.n());
+        assert_eq!(c.n(), module.n());
+        assert_ne!(a.as_ptr(), b.as_ptr());
+    }
+    let a_cols: usize = a.cols();
+    let b_cols: usize = b.cols();
+    let c_cols: usize = c.cols();
+    let min_ab_cols: usize = min(a_cols, b_cols);
+    let min_cols: usize = min(c_cols, min_ab_cols);
+    // Applies over the shared cols between (a, b, c).
+    (0..min_cols).for_each(|i| op(c.at_poly_mut(i, 0), a.at_poly(i, 0), b.at_poly(i, 0)));
+    // Copies/Negates/Zeroes the remaining cols if op is not in-place.
+    if c.as_ptr() != a.as_ptr() && c.as_ptr() != b.as_ptr() {
+        znx_post_process_ternary_op::<C, A, B, NEGATE>(c, a, b);
+    }
+}
+
+#[inline(always)]
+pub fn apply_unary_op<T: ZnxBasics + ZnxLayout, B: Backend>(
+    module: &Module<B>,
+    b: &mut T,
+    a: &T,
+    op: impl Fn(&mut [T::Scalar], &[T::Scalar]),
+) where
+    <T as ZnxLayout>::Scalar: IntegerType,
+{
+    #[cfg(debug_assertions)]
+    {
+        assert_eq!(a.n(), module.n());
+        assert_eq!(b.n(), module.n());
+    }
+    let a_cols: usize = a.cols();
+    let b_cols: usize = b.cols();
+    let min_cols: usize = min(a_cols, b_cols);
+    // Applies over the shared cols between (a, b)
+    (0..min_cols).for_each(|i| op(b.at_poly_mut(i, 0), a.at_poly(i, 0)));
+    // Zeroes the remaining cols of b.
+    (min_cols..b_cols).for_each(|i| (0..b.size()).for_each(|j| b.zero_at(i, j)));
+}
+
+pub fn ffi_ternary_op_factory<T>(
+    module_ptr: *const MODULE,
+    c_size: usize,
+    c_sl: usize,
+    a_size: usize,
+    a_sl: usize,
+    b_size: usize,
+    b_sl: usize,
+    op_fn: unsafe extern "C" fn(*const MODULE, *mut T, u64, u64, *const T, u64, u64, *const T, u64, u64),
+) -> impl Fn(&mut [T], &[T], &[T]) {
+    move |cv: &mut [T], av: &[T], bv: &[T]| unsafe {
+        op_fn(
+            module_ptr,
+            cv.as_mut_ptr(),
+            c_size as u64,
+            c_sl as u64,
+            av.as_ptr(),
+            a_size as u64,
+            a_sl as u64,
+            bv.as_ptr(),
+            b_size as u64,
+            b_sl as u64,
+        )
+    }
+}
+
+pub fn ffi_binary_op_factory_type_0<T>(
+    module_ptr: *const MODULE,
+    b_size: usize,
+    b_sl: usize,
+    a_size: usize,
+    a_sl: usize,
+    op_fn: unsafe extern "C" fn(*const MODULE, *mut T, u64, u64, *const T, u64, u64),
+) -> impl Fn(&mut [T], &[T]) {
+    move |bv: &mut [T], av: &[T]| unsafe {
+        op_fn(
+            module_ptr,
+            bv.as_mut_ptr(),
+            b_size as u64,
+            b_sl as u64,
+            av.as_ptr(),
+            a_size as u64,
+            a_sl as u64,
+        )
+    }
+}
+
+pub fn ffi_binary_op_factory_type_1<T>(
+    module_ptr: *const MODULE,
+    k: i64,
+    b_size: usize,
+    b_sl: usize,
+    a_size: usize,
+    a_sl: usize,
+    op_fn: unsafe extern "C" fn(*const MODULE, i64, *mut T, u64, u64, *const T, u64, u64),
+) -> impl Fn(&mut [T], &[T]) {
+    move |bv: &mut [T], av: &[T]| unsafe {
+        op_fn(
+            module_ptr,
+            k,
+            bv.as_mut_ptr(),
+            b_size as u64,
+            b_sl as u64,
+            av.as_ptr(),
+            a_size as u64,
+            a_sl as u64,
+        )
+    }
+}
diff --git a/base2k/src/lib.rs b/base2k/src/lib.rs
index 3c48319..2a9a899 100644
--- a/base2k/src/lib.rs
+++ b/base2k/src/lib.rs
@@ -3,6 +3,7 @@ pub mod encoding;
 #[allow(non_camel_case_types, non_snake_case, non_upper_case_globals, dead_code, improper_ctypes)]
 // Other modules and exports
 pub mod ffi;
+mod internals;
 pub mod mat_znx_dft;
 pub mod module;
 pub mod sampling;
@@ -10,6 +11,7 @@
 pub mod scalar_znx_dft;
 pub mod stats;
 pub mod vec_znx;
 pub mod vec_znx_big;
+pub mod vec_znx_big_ops;
 pub mod vec_znx_dft;
 pub mod vec_znx_ops;
 
@@ -23,6 +25,7 @@
 pub use scalar_znx_dft::*;
 pub use stats::*;
 pub use vec_znx::*;
 pub use vec_znx_big::*;
+pub use vec_znx_big_ops::*;
 pub use vec_znx_dft::*;
 pub use vec_znx_ops::*;
 
diff --git a/base2k/src/vec_znx_big.rs b/base2k/src/vec_znx_big.rs
index d54d72d..67b75a2 100644
--- a/base2k/src/vec_znx_big.rs
+++ b/base2k/src/vec_znx_big.rs
@@ -1,5 +1,5 @@
-use crate::ffi::vec_znx_big::{self, vec_znx_big_t};
-use crate::{Backend, FFT64, Module, VecZnx, ZnxBase, ZnxInfos, ZnxLayout, alloc_aligned, assert_alignement};
+use crate::ffi::vec_znx_big;
+use crate::{Backend, FFT64, Module, ZnxBase, ZnxBasics, ZnxInfos, ZnxLayout, alloc_aligned, assert_alignement};
 use std::marker::PhantomData;
 
 pub struct VecZnxBig<B: Backend> {
     pub size: usize,
     pub _marker: PhantomData<B>,
 }
+
+impl<B: Backend> ZnxBasics for VecZnxBig<B> {}
+
 impl<B: Backend> ZnxBase for VecZnxBig<B> {
     type Scalar = u8;
@@ -112,265 +115,3 @@ impl<B: Backend> VecZnxBig<B> {
         (0..self.size()).for_each(|i| println!("{}: {:?}", i, &self.at_limb(i)[..n]));
     }
 }
-
-pub trait VecZnxBigOps {
-    /// Allocates a vector Z[X]/(X^N+1) that stores not normalized values.
-    fn new_vec_znx_big(&self, cols: usize, size: usize) -> VecZnxBig<FFT64>;
-
-    /// Returns a new [VecZnxBig] with the provided bytes array as backing array.
-    ///
-    /// Behavior: takes ownership of the backing array.
-    ///
-    /// # Arguments
-    ///
-    /// * `cols`: the number of polynomials..
-    /// * `size`: the number of polynomials per column.
-    /// * `bytes`: a byte array of size at least [Module::bytes_of_vec_znx_big].
-    ///
-    /// # Panics
-    /// If `bytes.len()` < [Module::bytes_of_vec_znx_big].
-    fn new_vec_znx_big_from_bytes(&self, cols: usize, size: usize, bytes: &mut [u8]) -> VecZnxBig<FFT64>;
-
-    /// Returns a new [VecZnxBig] with the provided bytes array as backing array.
-    ///
-    /// Behavior: the backing array is only borrowed.
-    ///
-    /// # Arguments
-    ///
-    /// * `cols`: the number of polynomials..
-    /// * `size`: the number of polynomials per column.
-    /// * `bytes`: a byte array of size at least [Module::bytes_of_vec_znx_big].
-    ///
-    /// # Panics
-    /// If `bytes.len()` < [Module::bytes_of_vec_znx_big].
-    fn new_vec_znx_big_from_bytes_borrow(&self, cols: usize, size: usize, tmp_bytes: &mut [u8]) -> VecZnxBig<FFT64>;
-
-    /// Returns the minimum number of bytes necessary to allocate
-    /// a new [VecZnxBig] through [VecZnxBig::from_bytes].
-    fn bytes_of_vec_znx_big(&self, cols: usize, size: usize) -> usize;
-
-    /// Subtracts `a` to `b` and stores the result on `b`.
-    fn vec_znx_big_sub_small_a_inplace(&self, b: &mut VecZnxBig<FFT64>, a: &VecZnx);
-
-    /// Subtracts `b` to `a` and stores the result on `c`.
-    fn vec_znx_big_sub_small_a(&self, c: &mut VecZnxBig<FFT64>, a: &VecZnx, b: &VecZnxBig<FFT64>);
-
-    /// Adds `a` to `b` and stores the result on `c`.
-    fn vec_znx_big_add_small(&self, c: &mut VecZnxBig<FFT64>, a: &VecZnx, b: &VecZnxBig<FFT64>);
-
-    /// Adds `a` to `b` and stores the result on `b`.
-    fn vec_znx_big_add_small_inplace(&self, b: &mut VecZnxBig<FFT64>, a: &VecZnx);
-
-    /// Returns the minimum number of bytes to apply [VecZnxBigOps::vec_znx_big_normalize].
-    fn vec_znx_big_normalize_tmp_bytes(&self) -> usize;
-
-    /// Normalizes `a` and stores the result on `b`.
-    ///
-    /// # Arguments
-    ///
-    /// * `log_base2k`: normalization basis.
-    /// * `tmp_bytes`: scratch space of size at least [VecZnxBigOps::vec_znx_big_normalize].
-    fn vec_znx_big_normalize(&self, log_base2k: usize, b: &mut VecZnx, a: &VecZnxBig<FFT64>, tmp_bytes: &mut [u8]);
-
-    /// Returns the minimum number of bytes to apply [VecZnxBigOps::vec_znx_big_range_normalize_base2k].
-    fn vec_znx_big_range_normalize_base2k_tmp_bytes(&self) -> usize;
-
-    /// Normalize `a`, taking into account column interleaving and stores the result on `b`.
-    ///
-    /// # Arguments
-    ///
-    /// * `log_base2k`: normalization basis.
-    /// * `a_range_begin`: column to start.
-    /// * `a_range_end`: column to end.
-    /// * `a_range_step`: column step size.
-    /// * `tmp_bytes`: scratch space of size at least [VecZnxBigOps::vec_znx_big_range_normalize_base2k_tmp_bytes].
-    fn vec_znx_big_range_normalize_base2k(
-        &self,
-        log_base2k: usize,
-        b: &mut VecZnx,
-        a: &VecZnxBig<FFT64>,
-        a_range_begin: usize,
-        a_range_xend: usize,
-        a_range_step: usize,
-        tmp_bytes: &mut [u8],
-    );
-
-    /// Applies the automorphism X^i -> X^ik on `a` and stores the result on `b`.
-    fn vec_znx_big_automorphism(&self, k: i64, b: &mut VecZnxBig<FFT64>, a: &VecZnxBig<FFT64>);
-
-    /// Applies the automorphism X^i -> X^ik on `a` and stores the result on `a`.
-    fn vec_znx_big_automorphism_inplace(&self, k: i64, a: &mut VecZnxBig<FFT64>);
-}
-
-impl VecZnxBigOps for Module<FFT64> {
-    fn new_vec_znx_big(&self, cols: usize, size: usize) -> VecZnxBig<FFT64> {
-        VecZnxBig::new(self, cols, size)
-    }
-
-    fn new_vec_znx_big_from_bytes(&self, cols: usize, size: usize, bytes: &mut [u8]) -> VecZnxBig<FFT64> {
-        VecZnxBig::from_bytes(self, cols, size, bytes)
-    }
-
-    fn new_vec_znx_big_from_bytes_borrow(&self, cols: usize, size: usize, tmp_bytes: &mut [u8]) -> VecZnxBig<FFT64> {
-        VecZnxBig::from_bytes_borrow(self, cols, size, tmp_bytes)
-    }
-
-    fn bytes_of_vec_znx_big(&self, cols: usize, size: usize) -> usize {
-        VecZnxBig::bytes_of(self, cols, size)
-    }
-
-    fn vec_znx_big_sub_small_a_inplace(&self, b: &mut VecZnxBig<FFT64>, a: &VecZnx) {
-        unsafe {
-            vec_znx_big::vec_znx_big_sub_small_a(
-                self.ptr,
-                b.ptr as *mut vec_znx_big_t,
-                b.poly_count() as u64,
-                a.as_ptr(),
-                a.poly_count() as u64,
-                a.n() as u64,
-                b.ptr as *mut vec_znx_big_t,
-                b.poly_count() as u64,
-            )
-        }
-    }
-
-    fn vec_znx_big_sub_small_a(&self, c: &mut VecZnxBig<FFT64>, a: &VecZnx, b: &VecZnxBig<FFT64>) {
-        unsafe {
-            vec_znx_big::vec_znx_big_sub_small_a(
-                self.ptr,
-                c.ptr as *mut vec_znx_big_t,
-                c.poly_count() as u64,
-                a.as_ptr(),
-                a.poly_count() as u64,
-                a.n() as u64,
-                b.ptr as *mut vec_znx_big_t,
-                b.poly_count() as u64,
-            )
-        }
-    }
-
-    fn vec_znx_big_add_small(&self, c: &mut VecZnxBig<FFT64>, a: &VecZnx, b: &VecZnxBig<FFT64>) {
-        unsafe {
-            vec_znx_big::vec_znx_big_add_small(
-                self.ptr,
-                c.ptr as *mut vec_znx_big_t,
-                c.poly_count() as u64,
-                b.ptr as *mut vec_znx_big_t,
-                b.poly_count() as u64,
-                a.as_ptr(),
-                a.poly_count() as u64,
-                a.n() as u64,
-            )
-        }
-    }
-
-    fn vec_znx_big_add_small_inplace(&self, b: &mut VecZnxBig<FFT64>, a: &VecZnx) {
-        unsafe {
-            vec_znx_big::vec_znx_big_add_small(
-                self.ptr,
-                b.ptr as *mut vec_znx_big_t,
-                b.poly_count() as u64,
-                b.ptr as *mut vec_znx_big_t,
-                b.poly_count() as u64,
-                a.as_ptr(),
-                a.poly_count() as u64,
-                a.n() as u64,
-            )
-        }
-    }
-
-    fn vec_znx_big_normalize_tmp_bytes(&self) -> usize {
-        unsafe { vec_znx_big::vec_znx_big_normalize_base2k_tmp_bytes(self.ptr) as usize }
-    }
-
-    fn vec_znx_big_normalize(&self, log_base2k: usize, b: &mut VecZnx, a: &VecZnxBig<FFT64>, tmp_bytes: &mut [u8]) {
-        debug_assert!(
-            tmp_bytes.len() >= Self::vec_znx_big_normalize_tmp_bytes(self),
-            "invalid tmp_bytes: tmp_bytes.len()={} <= self.vec_znx_big_normalize_tmp_bytes()={}",
-            tmp_bytes.len(),
-            Self::vec_znx_big_normalize_tmp_bytes(self)
-        );
-        #[cfg(debug_assertions)]
-        {
-            assert_alignement(tmp_bytes.as_ptr())
-        }
-        unsafe {
-            vec_znx_big::vec_znx_big_normalize_base2k(
-                self.ptr,
-                log_base2k as u64,
-                b.as_mut_ptr(),
-                b.size() as u64,
-                b.n() as u64,
-                a.ptr as *mut vec_znx_big_t,
-                a.size() as u64,
-                tmp_bytes.as_mut_ptr(),
-            )
-        }
-    }
-
-    fn vec_znx_big_range_normalize_base2k_tmp_bytes(&self) -> usize {
-        unsafe { vec_znx_big::vec_znx_big_range_normalize_base2k_tmp_bytes(self.ptr) as usize }
-    }
-
-    fn vec_znx_big_range_normalize_base2k(
-        &self,
-        log_base2k: usize,
-        res: &mut VecZnx,
-        a: &VecZnxBig<FFT64>,
-        a_range_begin: usize,
-        a_range_xend: usize,
-        a_range_step: usize,
-        tmp_bytes: &mut [u8],
-    ) {
-        debug_assert!(
-            tmp_bytes.len() >= Self::vec_znx_big_range_normalize_base2k_tmp_bytes(self),
-            "invalid tmp_bytes: tmp_bytes.len()={} <= self.vec_znx_big_range_normalize_base2k_tmp_bytes()={}",
-            tmp_bytes.len(),
-            Self::vec_znx_big_range_normalize_base2k_tmp_bytes(self)
-        );
-        #[cfg(debug_assertions)]
-        {
-            assert_alignement(tmp_bytes.as_ptr())
-        }
-        unsafe {
-            vec_znx_big::vec_znx_big_range_normalize_base2k(
-                self.ptr,
-                log_base2k as u64,
-                res.as_mut_ptr(),
-                res.size() as u64,
-                res.n() as u64,
-                a.ptr as *mut vec_znx_big_t,
-                a_range_begin as u64,
-                a_range_xend as u64,
-                a_range_step as u64,
-                tmp_bytes.as_mut_ptr(),
-            );
-        }
-    }
-
-    fn vec_znx_big_automorphism(&self, gal_el: i64, b: &mut VecZnxBig<FFT64>, a: &VecZnxBig<FFT64>) {
-        unsafe {
-            vec_znx_big::vec_znx_big_automorphism(
-                self.ptr,
-                gal_el,
-                b.ptr as *mut vec_znx_big_t,
-                b.poly_count() as u64,
-                a.ptr as *mut vec_znx_big_t,
-                a.poly_count() as u64,
-            );
-        }
-    }
-
-    fn vec_znx_big_automorphism_inplace(&self, gal_el: i64, a: &mut VecZnxBig<FFT64>) {
-        unsafe {
-            vec_znx_big::vec_znx_big_automorphism(
-                self.ptr,
-                gal_el,
-                a.ptr as *mut vec_znx_big_t,
-                a.poly_count() as u64,
-                a.ptr as *mut vec_znx_big_t,
-                a.poly_count() as u64,
-            );
-        }
-    }
-}
diff --git a/base2k/src/vec_znx_big_ops.rs b/base2k/src/vec_znx_big_ops.rs
new file mode 100644
index 0000000..530cb54
--- /dev/null
+++ b/base2k/src/vec_znx_big_ops.rs
@@ -0,0 +1,339 @@
+use crate::ffi::vec_znx_big::vec_znx_big_t;
+use crate::ffi::{vec_znx, vec_znx_big};
+use crate::internals::{apply_binary_op, ffi_ternary_op_factory};
+use crate::{Backend, FFT64, Module, VecZnx, VecZnxBig, ZnxBase, ZnxInfos, ZnxLayout, assert_alignement};
+
+pub trait VecZnxBigOps {
+    /// Allocates a vector over Z[X]/(X^N+1) that stores non-normalized values.
+    fn new_vec_znx_big(&self, cols: usize, size: usize) -> VecZnxBig<FFT64>;
+
+    /// Returns a new [VecZnxBig] with the provided byte array as backing array.
+    ///
+    /// Behavior: takes ownership of the backing array.
+    ///
+    /// # Arguments
+    ///
+    /// * `cols`: the number of columns.
+    /// * `size`: the number of polynomials per column.
+    /// * `bytes`: a byte array of size at least [Module::bytes_of_vec_znx_big].
+    ///
+    /// # Panics
+    /// If `bytes.len()` < [Module::bytes_of_vec_znx_big].
+    fn new_vec_znx_big_from_bytes(&self, cols: usize, size: usize, bytes: &mut [u8]) -> VecZnxBig<FFT64>;
+
+    /// Returns a new [VecZnxBig] with the provided byte array as backing array.
+    ///
+    /// Behavior: the backing array is only borrowed.
+    ///
+    /// # Arguments
+    ///
+    /// * `cols`: the number of columns.
+    /// * `size`: the number of polynomials per column.
+    /// * `tmp_bytes`: a byte array of size at least [Module::bytes_of_vec_znx_big].
+    ///
+    /// # Panics
+    /// If `tmp_bytes.len()` < [Module::bytes_of_vec_znx_big].
+    fn new_vec_znx_big_from_bytes_borrow(&self, cols: usize, size: usize, tmp_bytes: &mut [u8]) -> VecZnxBig<FFT64>;
+
+    /// Returns the minimum number of bytes necessary to allocate
+    /// a new [VecZnxBig] through [VecZnxBig::from_bytes].
+    fn bytes_of_vec_znx_big(&self, cols: usize, size: usize) -> usize;
+
+    /// Adds `a` to `b` and stores the result in `c`.
+    fn vec_znx_big_add(&self, c: &mut VecZnxBig<FFT64>, a: &VecZnxBig<FFT64>, b: &VecZnxBig<FFT64>);
+
+    /// Adds `a` to `b` and stores the result in `b`.
+    fn vec_znx_big_add_inplace(&self, b: &mut VecZnxBig<FFT64>, a: &VecZnxBig<FFT64>);
+
+    /// Adds `a` to `b` and stores the result in `c`.
+    fn vec_znx_big_add_small(&self, c: &mut VecZnxBig<FFT64>, a: &VecZnx, b: &VecZnxBig<FFT64>);
+
+    /// Adds `a` to `b` and stores the result in `b`.
+    fn vec_znx_big_add_small_inplace(&self, b: &mut VecZnxBig<FFT64>, a: &VecZnx);
+
+    /// Subtracts `b` from `a` and stores the result in `c` (`c = a - b`).
+    fn vec_znx_big_sub(&self, c: &mut VecZnxBig<FFT64>, a: &VecZnxBig<FFT64>, b: &VecZnxBig<FFT64>);
+
+    /// Subtracts `b` from `a` and stores the result in `b` (`b = a - b`).
+    fn vec_znx_big_sub_ab_inplace(&self, b: &mut VecZnxBig<FFT64>, a: &VecZnxBig<FFT64>);
+
+    /// Subtracts `a` from `b` and stores the result in `b` (`b = b - a`).
+    fn vec_znx_big_sub_ba_inplace(&self, b: &mut VecZnxBig<FFT64>, a: &VecZnxBig<FFT64>);
+
+    /// Subtracts `b` from `a` and stores the result in `c` (`c = a - b`).
+    fn vec_znx_big_sub_small_ab(&self, c: &mut VecZnxBig<FFT64>, a: &VecZnx, b: &VecZnxBig<FFT64>);
+
+    /// Subtracts `b` from `a` and stores the result in `b` (`b = a - b`).
+    fn vec_znx_big_sub_small_ab_inplace(&self, b: &mut VecZnxBig<FFT64>, a: &VecZnx);
+
+    /// Subtracts `b` from `a` and stores the result in `c` (`c = a - b`).
+    fn vec_znx_big_sub_small_ba(&self, c: &mut VecZnxBig<FFT64>, a: &VecZnxBig<FFT64>, b: &VecZnx);
+
+    /// Subtracts `a` from `b` and stores the result in `b` (`b = b - a`).
+    fn vec_znx_big_sub_small_ba_inplace(&self, b: &mut VecZnxBig<FFT64>, a: &VecZnx);
+
+    /// Returns the minimum number of bytes to apply [VecZnxBigOps::vec_znx_big_normalize].
+    fn vec_znx_big_normalize_tmp_bytes(&self) -> usize;
+
+    /// Normalizes `a` and stores the result in `b`.
+    ///
+    /// # Arguments
+    ///
+    /// * `log_base2k`: normalization basis.
+    /// * `tmp_bytes`: scratch space of size at least [VecZnxBigOps::vec_znx_big_normalize_tmp_bytes].
+    fn vec_znx_big_normalize(&self, log_base2k: usize, b: &mut VecZnx, a: &VecZnxBig<FFT64>, tmp_bytes: &mut [u8]);
+
+    /// Returns the minimum number of bytes to apply [VecZnxBigOps::vec_znx_big_range_normalize_base2k].
+    fn vec_znx_big_range_normalize_base2k_tmp_bytes(&self) -> usize;
+
+    /// Normalizes `a`, taking column interleaving into account, and stores the result in `b`.
+    ///
+    /// # Arguments
+    ///
+    /// * `log_base2k`: normalization basis.
+    /// * `a_range_begin`: column to start at.
+    /// * `a_range_xend`: column to end at (exclusive).
+    /// * `a_range_step`: column step size.
+    /// * `tmp_bytes`: scratch space of size at least [VecZnxBigOps::vec_znx_big_range_normalize_base2k_tmp_bytes].
+    fn vec_znx_big_range_normalize_base2k(
+        &self,
+        log_base2k: usize,
+        b: &mut VecZnx,
+        a: &VecZnxBig<FFT64>,
+        a_range_begin: usize,
+        a_range_xend: usize,
+        a_range_step: usize,
+        tmp_bytes: &mut [u8],
+    );
+
+    /// Applies the automorphism X^i -> X^ik on `a` and stores the result in `b`.
+    fn vec_znx_big_automorphism(&self, k: i64, b: &mut VecZnxBig<FFT64>, a: &VecZnxBig<FFT64>);
+
+    /// Applies the automorphism X^i -> X^ik on `a` and stores the result in `a`.
+    fn vec_znx_big_automorphism_inplace(&self, k: i64, a: &mut VecZnxBig<FFT64>);
+}
+
+impl VecZnxBigOps for Module<FFT64> {
+    fn new_vec_znx_big(&self, cols: usize, size: usize) -> VecZnxBig<FFT64> {
+        VecZnxBig::new(self, cols, size)
+    }
+
+    fn new_vec_znx_big_from_bytes(&self, cols: usize, size: usize, bytes: &mut [u8]) -> VecZnxBig<FFT64> {
+        VecZnxBig::from_bytes(self, cols, size, bytes)
+    }
+
+    fn new_vec_znx_big_from_bytes_borrow(&self, cols: usize, size: usize, tmp_bytes: &mut [u8]) -> VecZnxBig<FFT64> {
+        VecZnxBig::from_bytes_borrow(self, cols, size, tmp_bytes)
+    }
+
+    fn bytes_of_vec_znx_big(&self, cols: usize, size: usize) -> usize {
+        VecZnxBig::bytes_of(self, cols, size)
+    }
+
+    fn vec_znx_big_add(&self, c: &mut VecZnxBig<FFT64>, a: &VecZnxBig<FFT64>, b: &VecZnxBig<FFT64>) {
+        let op = ffi_ternary_op_factory(
+            self.ptr,
+            c.size(),
+            c.sl(),
+            a.size(),
+            a.sl(),
+            b.size(),
+            b.sl(),
+            vec_znx::vec_znx_add,
+        );
+        apply_binary_op::<FFT64, VecZnxBig<FFT64>, VecZnxBig<FFT64>, VecZnxBig<FFT64>, false>(self, c, a, b, op);
+    }
+
+    fn vec_znx_big_add_inplace(&self, b: &mut VecZnxBig<FFT64>, a: &VecZnxBig<FFT64>) {
+        unsafe {
+            let b_ptr: *mut VecZnxBig<FFT64> = b as *mut VecZnxBig<FFT64>;
+            Self::vec_znx_big_add(self, &mut *b_ptr, a, &*b_ptr);
+        }
+    }
+
+    fn vec_znx_big_sub(&self, c: &mut VecZnxBig<FFT64>, a: &VecZnxBig<FFT64>, b: &VecZnxBig<FFT64>) {
+        let op = ffi_ternary_op_factory(
+            self.ptr,
+            c.size(),
+            c.sl(),
+            a.size(),
+            a.sl(),
+            b.size(),
+            b.sl(),
+            vec_znx::vec_znx_sub,
+        );
+        apply_binary_op::<FFT64, VecZnxBig<FFT64>, VecZnxBig<FFT64>, VecZnxBig<FFT64>, true>(self, c, a, b, op);
+    }
+
+    fn vec_znx_big_sub_ab_inplace(&self, b: &mut VecZnxBig<FFT64>, a: &VecZnxBig<FFT64>) {
+        unsafe {
+            let b_ptr: *mut VecZnxBig<FFT64> = b as *mut VecZnxBig<FFT64>;
+            Self::vec_znx_big_sub(self, &mut *b_ptr, a, &*b_ptr);
+        }
+    }
+
+    fn vec_znx_big_sub_ba_inplace(&self, b: &mut VecZnxBig<FFT64>, a: &VecZnxBig<FFT64>) {
+        unsafe {
+            let b_ptr: *mut VecZnxBig<FFT64> = b as *mut VecZnxBig<FFT64>;
+            Self::vec_znx_big_sub(self, &mut *b_ptr, &*b_ptr, a);
+        }
+    }
+
+    fn vec_znx_big_sub_small_ba(&self, c: &mut VecZnxBig<FFT64>, a: &VecZnxBig<FFT64>, b: &VecZnx) {
+        let op = ffi_ternary_op_factory(
+            self.ptr,
+            c.size(),
+            c.sl(),
+            a.size(),
+            a.sl(),
+            b.size(),
+            b.sl(),
+            vec_znx::vec_znx_sub,
+        );
+        apply_binary_op::<FFT64, VecZnxBig<FFT64>, VecZnxBig<FFT64>, VecZnx, true>(self, c, a, b, op);
+    }
+
+    fn vec_znx_big_sub_small_ba_inplace(&self, b: &mut VecZnxBig<FFT64>, a: &VecZnx) {
+        unsafe {
+            let b_ptr: *mut VecZnxBig<FFT64> = b as *mut VecZnxBig<FFT64>;
+            Self::vec_znx_big_sub_small_ba(self, &mut *b_ptr, &*b_ptr, a);
+        }
+    }
+
+    fn vec_znx_big_sub_small_ab(&self, c: &mut VecZnxBig<FFT64>, a: &VecZnx, b: &VecZnxBig<FFT64>) {
+        let op = ffi_ternary_op_factory(
+            self.ptr,
+            c.size(),
+            c.sl(),
+            a.size(),
+            a.sl(),
+            b.size(),
+            b.sl(),
+            vec_znx::vec_znx_sub,
+        );
+        apply_binary_op::<FFT64, VecZnxBig<FFT64>, VecZnx, VecZnxBig<FFT64>, true>(self, c, a, b, op);
+    }
+
+    fn vec_znx_big_sub_small_ab_inplace(&self, b: &mut VecZnxBig<FFT64>, a: &VecZnx) {
+        unsafe {
+            let b_ptr: *mut VecZnxBig<FFT64> = b as *mut VecZnxBig<FFT64>;
+            Self::vec_znx_big_sub_small_ab(self, &mut *b_ptr, a, &*b_ptr);
+        }
+    }
+
+    fn vec_znx_big_add_small(&self, c: &mut VecZnxBig<FFT64>, a: &VecZnx, b: &VecZnxBig<FFT64>) {
+        let op = ffi_ternary_op_factory(
+            self.ptr,
+            c.size(),
+            c.sl(),
+            a.size(),
+            a.sl(),
+            b.size(),
+            b.sl(),
+            vec_znx::vec_znx_add,
+        );
+        apply_binary_op::<FFT64, VecZnxBig<FFT64>, VecZnx, VecZnxBig<FFT64>, false>(self, c, a, b, op);
+    }
+
+    fn vec_znx_big_add_small_inplace(&self, b: &mut VecZnxBig<FFT64>, a: &VecZnx) {
+        unsafe {
+            let b_ptr: *mut VecZnxBig<FFT64> = b as *mut VecZnxBig<FFT64>;
+            Self::vec_znx_big_add_small(self, &mut *b_ptr, a, &*b_ptr);
+        }
+    }
+
+    fn vec_znx_big_normalize_tmp_bytes(&self) -> usize {
+        unsafe { vec_znx_big::vec_znx_big_normalize_base2k_tmp_bytes(self.ptr) as usize }
+    }
+
+    fn vec_znx_big_normalize(&self, log_base2k: usize, b: &mut VecZnx, a: &VecZnxBig<FFT64>, tmp_bytes: &mut [u8]) {
+        debug_assert!(
+            tmp_bytes.len() >= Self::vec_znx_big_normalize_tmp_bytes(self),
+            "invalid tmp_bytes: tmp_bytes.len()={} < self.vec_znx_big_normalize_tmp_bytes()={}",
+            tmp_bytes.len(),
+            Self::vec_znx_big_normalize_tmp_bytes(self)
+        );
+        #[cfg(debug_assertions)]
+        {
+            assert_alignement(tmp_bytes.as_ptr())
+        }
+        unsafe {
+            vec_znx_big::vec_znx_big_normalize_base2k(
+                self.ptr,
+                log_base2k as u64,
+                b.as_mut_ptr(),
+                b.size() as u64,
+                b.n() as u64,
+                a.ptr as *mut vec_znx_big_t,
+                a.size() as u64,
+                tmp_bytes.as_mut_ptr(),
+            )
+        }
+    }
+
+    fn vec_znx_big_range_normalize_base2k_tmp_bytes(&self) -> usize {
+        unsafe { vec_znx_big::vec_znx_big_range_normalize_base2k_tmp_bytes(self.ptr) as usize }
+    }
+
+    fn vec_znx_big_range_normalize_base2k(
+        &self,
+        log_base2k: usize,
+        res: &mut VecZnx,
+        a: &VecZnxBig<FFT64>,
+        a_range_begin: usize,
+        a_range_xend: usize,
+        a_range_step: usize,
+        tmp_bytes: &mut [u8],
+    ) {
+        debug_assert!(
+            tmp_bytes.len() >= Self::vec_znx_big_range_normalize_base2k_tmp_bytes(self),
+            "invalid tmp_bytes: tmp_bytes.len()={} < self.vec_znx_big_range_normalize_base2k_tmp_bytes()={}",
+            tmp_bytes.len(),
+            Self::vec_znx_big_range_normalize_base2k_tmp_bytes(self)
+        );
+        #[cfg(debug_assertions)]
+        {
+            assert_alignement(tmp_bytes.as_ptr())
+        }
+        unsafe {
+            vec_znx_big::vec_znx_big_range_normalize_base2k(
+                self.ptr,
+                log_base2k as u64,
+                res.as_mut_ptr(),
+                res.size() as u64,
+                res.n() as u64,
+                a.ptr as *mut vec_znx_big_t,
+                a_range_begin as u64,
+                a_range_xend as u64,
+                a_range_step as u64,
+                tmp_bytes.as_mut_ptr(),
+            );
+        }
+    }
+
+    fn vec_znx_big_automorphism(&self, gal_el: i64, b: &mut VecZnxBig<FFT64>, a: &VecZnxBig<FFT64>) {
+        unsafe {
+            vec_znx_big::vec_znx_big_automorphism(
+                self.ptr,
+                gal_el,
+                b.ptr as *mut vec_znx_big_t,
+                b.poly_count() as u64,
+                a.ptr as *mut vec_znx_big_t,
+                a.poly_count() as u64,
+            );
+        }
+    }
+
+    fn vec_znx_big_automorphism_inplace(&self, gal_el: i64, a: &mut VecZnxBig<FFT64>) {
+        unsafe {
+            vec_znx_big::vec_znx_big_automorphism(
+                self.ptr,
+                gal_el,
+                a.ptr as *mut vec_znx_big_t,
+                a.poly_count() as u64,
+                a.ptr as *mut vec_znx_big_t,
+                a.poly_count() as u64,
+            );
+        }
+    }
+}
diff --git a/base2k/src/vec_znx_ops.rs b/base2k/src/vec_znx_ops.rs
index 573e5b1..f67b6b0 100644
--- a/base2k/src/vec_znx_ops.rs
+++ b/base2k/src/vec_znx_ops.rs
@@ -1,7 +1,7 @@
 use crate::ffi::module::MODULE;
 use crate::ffi::vec_znx;
-use crate::{apply_binary_op, apply_unary_op, switch_degree, znx_post_process_ternary_op, Backend, Module, VecZnx, ZnxBase, ZnxBasics, ZnxInfos, ZnxLayout};
-use std::cmp::min;
+use crate::internals::{apply_binary_op, apply_unary_op, ffi_binary_op_factory_type_0, ffi_binary_op_factory_type_1};
+use crate::{Backend, Module, VecZnx, ZnxBase, ZnxInfos, switch_degree};
 
 pub trait VecZnxOps {
     /// Allocates a new [VecZnx].
     ///
@@ -125,7 +125,7 @@ impl<B: Backend> VecZnxOps for Module<B> {
             b.sl(),
             vec_znx::vec_znx_add,
         );
-        apply_binary_op::<VecZnx, B, false>(self, c, a, b, op);
+        apply_binary_op::<B, VecZnx, VecZnx, VecZnx, false>(self, c, a, b, op);
     }
 
     fn vec_znx_add_inplace(&self, b: &mut VecZnx, a: &VecZnx) {
@@ -146,7 +146,7 @@
             b.sl(),
             vec_znx::vec_znx_sub,
         );
-        apply_binary_op::<VecZnx, B, true>(self, c, a, b, op);
+        apply_binary_op::<B, VecZnx, VecZnx, VecZnx, true>(self, c, a, b, op);
     }
 
     fn vec_znx_sub_ab_inplace(&self, b: &mut VecZnx, a: &VecZnx) {
@@ -298,56 +298,11 @@
     }
 }
 
-fn ffi_binary_op_factory_type_0(
-    module_ptr: *const MODULE,
-    b_size: usize,
-    b_sl: usize,
-    a_size: usize,
-    a_sl: usize,
-    op_fn: unsafe extern "C" fn(*const MODULE, *mut i64, u64, u64, *const i64, u64, u64),
-) -> impl Fn(&mut [i64], &[i64]) {
-    move |bv: &mut [i64], av: &[i64]| unsafe {
-        op_fn(
-            module_ptr,
-            bv.as_mut_ptr(),
-            b_size as u64,
-            b_sl as u64,
-            av.as_ptr(),
-            a_size as u64,
-            a_sl as u64,
-        )
-    }
-}
-
-fn ffi_binary_op_factory_type_1(
-    module_ptr: *const MODULE,
-    k: i64,
-    b_size: usize,
-    b_sl: usize,
-    a_size: usize,
-    a_sl: usize,
-    op_fn: unsafe extern "C" fn(*const MODULE, i64, *mut i64, u64, u64, *const i64, u64, u64),
-) -> impl Fn(&mut [i64], &[i64]) {
-    move |bv: &mut [i64], av: &[i64]| unsafe {
-        op_fn(
-            module_ptr,
-            k,
-            bv.as_mut_ptr(),
-            b_size as u64,
-            b_sl as u64,
-            av.as_ptr(),
-            a_size as u64,
-            a_sl as u64,
-        )
-    }
-}
-
 #[cfg(test)]
 mod tests {
-    use crate::{
-        Backend, FFT64, Module, Sampling, VecZnx, VecZnxOps, ZnxBasics, ZnxInfos, ZnxLayout, ffi::vec_znx,
-        znx_post_process_ternary_op,
-    };
+    use crate::internals::znx_post_process_ternary_op;
+    use crate::{Backend, FFT64, Module, Sampling, VecZnx, VecZnxOps, ZnxBasics, ZnxInfos, ZnxLayout, ffi::vec_znx};
+    use itertools::izip;
     use sampling::source::Source;
     use std::cmp::min;
 
@@ -623,7 +578,7 @@
             }
         });
 
-        znx_post_process_ternary_op::<_, NEGATE>(&mut c_want, &a, &b);
+        znx_post_process_ternary_op::<VecZnx, VecZnx, VecZnx, NEGATE>(&mut c_want, &a, &b);
 
         assert_eq!(c_have.raw(), c_want.raw());
     });
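Below is a standalone sketch, not part of the diff, of the column-handling contract that `apply_binary_op` and `znx_post_process_ternary_op` implement above, replayed on a toy column type (`Cols` and all names here are invented for illustration): the supplied op runs only on the columns shared by `a`, `b` and `c`; columns covered by exactly one operand are copied into `c`, negated when `NEGATE` is set, mirroring the patch's post-processing for the subtraction variants; and columns of `c` beyond both operands are zeroed.

#[derive(Clone, Debug, PartialEq)]
struct Cols {
    cols: Vec<Vec<i64>>, // each inner Vec is one column (one small polynomial)
}

fn apply_binary_op_toy<const NEGATE: bool>(
    c: &mut Cols,
    a: &Cols,
    b: &Cols,
    op: impl Fn(&mut [i64], &[i64], &[i64]),
) {
    let min_ab = a.cols.len().min(b.cols.len());
    let shared = c.cols.len().min(min_ab);

    // 1) Apply the real op on the columns present in a, b and c.
    for i in 0..shared {
        op(&mut c.cols[i], &a.cols[i], &b.cols[i]);
    }

    // 2) Columns covered by exactly one operand are copied into c,
    //    negated when NEGATE is set (mirrors znx_post_process_ternary_op).
    let max_ab = a.cols.len().max(b.cols.len());
    let longer = if a.cols.len() > b.cols.len() { &a.cols } else { &b.cols };
    for i in shared..max_ab.min(c.cols.len()) {
        for (dst, src) in c.cols[i].iter_mut().zip(&longer[i]) {
            *dst = if NEGATE { -*src } else { *src };
        }
    }

    // 3) Columns of c beyond both operands are zeroed.
    for i in max_ab.min(c.cols.len())..c.cols.len() {
        c.cols[i].fill(0);
    }
}

fn main() {
    let a = Cols { cols: vec![vec![5, 6], vec![7, 8]] };
    let b = Cols { cols: vec![vec![1, 1]] };
    let mut c = Cols { cols: vec![vec![0, 0]; 3] };

    // c = a - b on the shared column; the tail columns follow rules 2) and 3).
    apply_binary_op_toy::<true>(&mut c, &a, &b, |cv, av, bv| {
        for ((x, y), z) in cv.iter_mut().zip(av).zip(bv) {
            *x = *y - *z;
        }
    });

    assert_eq!(c.cols, vec![vec![4, 5], vec![-7, -8], vec![0, 0]]);
    println!("{:?}", c.cols);
}

Factoring the op out as a closure (built by the `ffi_*_op_factory` helpers in internals.rs) is what lets one generic `apply_binary_op` serve both the `VecZnx` and `VecZnxBig` trait surfaces.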
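And a hypothetical end-to-end use of the new `VecZnxBigOps` surface, in the spirit of the `rlwe_encrypt.rs` change above. Only the trait methods come from this patch; `Module::<FFT64>::new`, the `new_vec_znx` allocator arguments and the `alloc_aligned` helper signature are assumptions and may differ in the crate.

use base2k::{alloc_aligned, FFT64, Module, VecZnxBigOps, VecZnxOps};

fn main() {
    // Assumed constructor: a module of ring degree n = 1024.
    let module: Module<FFT64> = Module::<FFT64>::new(1024);
    let log_base2k: usize = 17;

    // Assumed allocator signatures: (cols, size).
    let m = module.new_vec_znx(1, 4);               // small, normalized vector
    let mut buf_big = module.new_vec_znx_big(1, 4); // big, non-normalized accumulator

    // buf_big <- m - buf_big (the renamed in-place variant used in rlwe_encrypt.rs).
    module.vec_znx_big_sub_small_ab_inplace(&mut buf_big, &m);

    // Fold the big accumulator back into a normalized VecZnx. The scratch
    // buffer must be aligned (debug builds call assert_alignement on it);
    // alloc_aligned is assumed to return a suitably aligned Vec<u8>.
    let mut res = module.new_vec_znx(1, 4);
    let mut tmp_bytes: Vec<u8> = alloc_aligned(module.vec_znx_big_normalize_tmp_bytes());
    module.vec_znx_big_normalize(log_base2k, &mut res, &buf_big, &mut tmp_bytes);
}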