Added VecZnxBig<FFT64> ops

Jean-Philippe Bossuat
2025-04-29 15:53:26 +02:00
parent 3ee69866bd
commit bd933c0e94
7 changed files with 549 additions and 421 deletions

View File

@@ -59,7 +59,7 @@ fn main() {
m.normalize(log_base2k, &mut carry);
// buf_big <- m - buf_big
- module.vec_znx_big_sub_small_a_inplace(&mut buf_big, &m);
+ module.vec_znx_big_sub_small_ab_inplace(&mut buf_big, &m);
println!("{:?}", buf_big.raw());

View File

@@ -1,6 +1,6 @@
use crate::{Backend, Module, assert_alignement, cast_mut};
use itertools::izip;
- use std::cmp::{max, min};
+ use std::cmp::min;
pub trait ZnxInfos {
/// Returns the ring degree of the polynomials.
@@ -243,105 +243,3 @@ where
.for_each(|(x_in, x_out)| *x_out = *x_in);
});
}
pub fn znx_post_process_ternary_op<T: ZnxLayout + ZnxBasics, const NEGATE: bool>(c: &mut T, a: &T, b: &T)
where
<T as ZnxLayout>::Scalar: IntegerType,
{
#[cfg(debug_assertions)]
{
assert_ne!(a.as_ptr(), b.as_ptr());
assert_ne!(b.as_ptr(), c.as_ptr());
assert_ne!(a.as_ptr(), c.as_ptr());
}
let a_cols: usize = a.cols();
let b_cols: usize = b.cols();
let c_cols: usize = c.cols();
let min_ab_cols: usize = min(a_cols, b_cols);
let max_ab_cols: usize = max(a_cols, b_cols);
// Copies shared shared cols between (c, max(a, b))
if a_cols != b_cols {
let mut x: &T = a;
if a_cols < b_cols {
x = b;
}
let min_size = min(c.size(), x.size());
(min_ab_cols..min(max_ab_cols, c_cols)).for_each(|i| {
(0..min_size).for_each(|j| {
c.at_poly_mut(i, j).copy_from_slice(x.at_poly(i, j));
if NEGATE {
c.at_poly_mut(i, j).iter_mut().for_each(|x| *x = -*x);
}
});
(min_size..c.size()).for_each(|j| {
c.zero_at(i, j);
});
});
}
// Zeroes the cols of c > max(a, b).
if c_cols > max_ab_cols {
(max_ab_cols..c_cols).for_each(|i| {
(0..c.size()).for_each(|j| {
c.zero_at(i, j);
})
});
}
}
#[inline(always)]
pub fn apply_binary_op<B: Backend, T: ZnxBasics + ZnxLayout, const NEGATE: bool>(
module: &Module<B>,
c: &mut T,
a: &T,
b: &T,
op: impl Fn(&mut [T::Scalar], &[T::Scalar], &[T::Scalar]),
) where
<T as ZnxLayout>::Scalar: IntegerType,
{
#[cfg(debug_assertions)]
{
assert_eq!(a.n(), module.n());
assert_eq!(b.n(), module.n());
assert_eq!(c.n(), module.n());
assert_ne!(a.as_ptr(), b.as_ptr());
}
let a_cols: usize = a.cols();
let b_cols: usize = b.cols();
let c_cols: usize = c.cols();
let min_ab_cols: usize = min(a_cols, b_cols);
let min_cols: usize = min(c_cols, min_ab_cols);
// Applies over shared cols between (a, b, c)
(0..min_cols).for_each(|i| op(c.at_poly_mut(i, 0), a.at_poly(i, 0), b.at_poly(i, 0)));
// Copies/Negates/Zeroes the remaining cols if op is not inplace.
if c.as_ptr() != a.as_ptr() && c.as_ptr() != b.as_ptr() {
znx_post_process_ternary_op::<T, NEGATE>(c, a, b);
}
}
#[inline(always)]
pub fn apply_unary_op<B: Backend, T: ZnxBasics + ZnxLayout>(
module: &Module<B>,
b: &mut T,
a: &T,
op: impl Fn(&mut [T::Scalar], &[T::Scalar]),
) where
<T as ZnxLayout>::Scalar: IntegerType,
{
#[cfg(debug_assertions)]
{
assert_eq!(a.n(), module.n());
assert_eq!(b.n(), module.n());
}
let a_cols: usize = a.cols();
let b_cols: usize = b.cols();
let min_cols: usize = min(a_cols, b_cols);
// Applies over the shared cols between (a, b)
(0..min_cols).for_each(|i| op(b.at_poly_mut(i, 0), a.at_poly(i, 0)));
// Zeroes the remaining cols of b.
(min_cols..b_cols).for_each(|i| (0..b.size()).for_each(|j| b.zero_at(i, j)));
}

base2k/src/internals.rs (new file, 192 additions)
View File

@@ -0,0 +1,192 @@
use std::cmp::{max, min};
use crate::{Backend, IntegerType, Module, ZnxBasics, ZnxLayout, ffi::module::MODULE};
pub(crate) fn znx_post_process_ternary_op<C, A, B, const NEGATE: bool>(c: &mut C, a: &A, b: &B)
where
C: ZnxBasics + ZnxLayout,
A: ZnxBasics + ZnxLayout<Scalar = C::Scalar>,
B: ZnxBasics + ZnxLayout<Scalar = C::Scalar>,
C::Scalar: IntegerType,
{
#[cfg(debug_assertions)]
{
assert_ne!(a.as_ptr(), b.as_ptr());
assert_ne!(b.as_ptr(), c.as_ptr());
assert_ne!(a.as_ptr(), c.as_ptr());
}
let a_cols: usize = a.cols();
let b_cols: usize = b.cols();
let c_cols: usize = c.cols();
let min_ab_cols: usize = min(a_cols, b_cols);
let max_ab_cols: usize = max(a_cols, b_cols);
// Copies into c the shared cols of the wider of (a, b)
if a_cols != b_cols {
if a_cols > b_cols {
let min_size = min(c.size(), a.size());
(min_ab_cols..min(max_ab_cols, c_cols)).for_each(|i| {
(0..min_size).for_each(|j| {
c.at_poly_mut(i, j).copy_from_slice(a.at_poly(i, j));
if NEGATE {
c.at_poly_mut(i, j).iter_mut().for_each(|x| *x = -*x);
}
});
(min_size..c.size()).for_each(|j| {
c.zero_at(i, j);
});
});
} else {
let min_size = min(c.size(), b.size());
(min_ab_cols..min(max_ab_cols, c_cols)).for_each(|i| {
(0..min_size).for_each(|j| {
c.at_poly_mut(i, j).copy_from_slice(b.at_poly(i, j));
if NEGATE {
c.at_poly_mut(i, j).iter_mut().for_each(|x| *x = -*x);
}
});
(min_size..c.size()).for_each(|j| {
c.zero_at(i, j);
});
});
}
}
// Zeroes the cols of c > max(a, b).
if c_cols > max_ab_cols {
(max_ab_cols..c_cols).for_each(|i| {
(0..c.size()).for_each(|j| {
c.zero_at(i, j);
})
});
}
}
#[inline(always)]
pub fn apply_binary_op<BE, C, A, B, const NEGATE: bool>(
module: &Module<BE>,
c: &mut C,
a: &A,
b: &B,
op: impl Fn(&mut [C::Scalar], &[A::Scalar], &[B::Scalar]),
) where
BE: Backend,
C: ZnxBasics + ZnxLayout,
A: ZnxBasics + ZnxLayout<Scalar = C::Scalar>,
B: ZnxBasics + ZnxLayout<Scalar = C::Scalar>,
C::Scalar: IntegerType,
{
#[cfg(debug_assertions)]
{
assert_eq!(a.n(), module.n());
assert_eq!(b.n(), module.n());
assert_eq!(c.n(), module.n());
assert_ne!(a.as_ptr(), b.as_ptr());
}
let a_cols: usize = a.cols();
let b_cols: usize = b.cols();
let c_cols: usize = c.cols();
let min_ab_cols: usize = min(a_cols, b_cols);
let min_cols: usize = min(c_cols, min_ab_cols);
// Applies over shared cols between (a, b, c)
(0..min_cols).for_each(|i| op(c.at_poly_mut(i, 0), a.at_poly(i, 0), b.at_poly(i, 0)));
// Copies/Negates/Zeroes the remaining cols if op is not inplace.
if c.as_ptr() != a.as_ptr() && c.as_ptr() != b.as_ptr() {
znx_post_process_ternary_op::<C, A, B, NEGATE>(c, a, b);
}
}
#[inline(always)]
pub fn apply_unary_op<B: Backend, T: ZnxBasics + ZnxLayout>(
module: &Module<B>,
b: &mut T,
a: &T,
op: impl Fn(&mut [T::Scalar], &[T::Scalar]),
) where
<T as ZnxLayout>::Scalar: IntegerType,
{
#[cfg(debug_assertions)]
{
assert_eq!(a.n(), module.n());
assert_eq!(b.n(), module.n());
}
let a_cols: usize = a.cols();
let b_cols: usize = b.cols();
let min_cols: usize = min(a_cols, b_cols);
// Applies over the shared cols between (a, b)
(0..min_cols).for_each(|i| op(b.at_poly_mut(i, 0), a.at_poly(i, 0)));
// Zeroes the remaining cols of b.
(min_cols..b_cols).for_each(|i| (0..b.size()).for_each(|j| b.zero_at(i, j)));
}
pub fn ffi_ternary_op_factory<T>(
module_ptr: *const MODULE,
c_size: usize,
c_sl: usize,
a_size: usize,
a_sl: usize,
b_size: usize,
b_sl: usize,
op_fn: unsafe extern "C" fn(*const MODULE, *mut T, u64, u64, *const T, u64, u64, *const T, u64, u64),
) -> impl Fn(&mut [T], &[T], &[T]) {
move |cv: &mut [T], av: &[T], bv: &[T]| unsafe {
op_fn(
module_ptr,
cv.as_mut_ptr(),
c_size as u64,
c_sl as u64,
av.as_ptr(),
a_size as u64,
a_sl as u64,
bv.as_ptr(),
b_size as u64,
b_sl as u64,
)
}
}
pub fn ffi_binary_op_factory_type_0<T>(
module_ptr: *const MODULE,
b_size: usize,
b_sl: usize,
a_size: usize,
a_sl: usize,
op_fn: unsafe extern "C" fn(*const MODULE, *mut T, u64, u64, *const T, u64, u64),
) -> impl Fn(&mut [T], &[T]) {
move |bv: &mut [T], av: &[T]| unsafe {
op_fn(
module_ptr,
bv.as_mut_ptr(),
b_size as u64,
b_sl as u64,
av.as_ptr(),
a_size as u64,
a_sl as u64,
)
}
}
pub fn ffi_binary_op_factory_type_1<T>(
module_ptr: *const MODULE,
k: i64,
b_size: usize,
b_sl: usize,
a_size: usize,
a_sl: usize,
op_fn: unsafe extern "C" fn(*const MODULE, i64, *mut T, u64, u64, *const T, u64, u64),
) -> impl Fn(&mut [T], &[T]) {
move |bv: &mut [T], av: &[T]| unsafe {
op_fn(
module_ptr,
k,
bv.as_mut_ptr(),
b_size as u64,
b_sl as u64,
av.as_ptr(),
a_size as u64,
a_sl as u64,
)
}
}
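Taken together, the helpers above are meant to compose as follows: an `ffi_*_op_factory` closure captures the module pointer plus the size/stride metadata of each operand and forwards one column at a time to the C kernel, while `apply_binary_op` drives that closure over the columns shared by all operands and then lets `znx_post_process_ternary_op` copy, negate, or zero whatever columns remain. A sketch of how a backend op is assembled inside `impl<B: Backend> VecZnxOps for Module<B>` (it mirrors the existing `vec_znx_add` in `vec_znx_ops.rs`; imports are `crate::internals::{apply_binary_op, ffi_ternary_op_factory}` plus the crate's layout traits, as in that file):

fn vec_znx_add(&self, c: &mut VecZnx, a: &VecZnx, b: &VecZnx) {
    // Closure invoking the C kernel on one column of each operand.
    let op = ffi_ternary_op_factory(
        self.ptr,
        c.size(), c.sl(),
        a.size(), a.sl(),
        b.size(), b.sl(),
        vec_znx::vec_znx_add,
    );
    // NEGATE = false: columns of `c` beyond min(a.cols(), b.cols()) are copied
    // unchanged from the wider operand; columns beyond max(a.cols(), b.cols()) are zeroed.
    apply_binary_op::<B, VecZnx, VecZnx, VecZnx, false>(self, c, a, b, op);
}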

View File

@@ -3,6 +3,7 @@ pub mod encoding;
#[allow(non_camel_case_types, non_snake_case, non_upper_case_globals, dead_code, improper_ctypes)]
// Other modules and exports
pub mod ffi;
mod internals;
pub mod mat_znx_dft;
pub mod module;
pub mod sampling;
@@ -10,6 +11,7 @@ pub mod scalar_znx_dft;
pub mod stats;
pub mod vec_znx;
pub mod vec_znx_big;
pub mod vec_znx_big_ops;
pub mod vec_znx_dft;
pub mod vec_znx_ops;
@@ -23,6 +25,7 @@ pub use scalar_znx_dft::*;
pub use stats::*;
pub use vec_znx::*;
pub use vec_znx_big::*;
pub use vec_znx_big_ops::*;
pub use vec_znx_dft::*;
pub use vec_znx_ops::*;
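Since `internals` is declared without `pub` while `vec_znx_big_ops` is both public and re-exported at the crate root, the helper factories stay crate-private and downstream users only need the trait in scope. A usage sketch, assuming the crate is named `base2k` as the file paths above suggest:

use base2k::{FFT64, Module, VecZnxBig, VecZnxBigOps};

// Allocates an unnormalized accumulator through the re-exported trait.
fn alloc_accumulator(module: &Module<FFT64>, cols: usize, size: usize) -> VecZnxBig<FFT64> {
    module.new_vec_znx_big(cols, size)
}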

View File

@@ -1,5 +1,5 @@
- use crate::ffi::vec_znx_big::{self, vec_znx_big_t};
+ use crate::ffi::vec_znx_big;
- use crate::{Backend, FFT64, Module, VecZnx, ZnxBase, ZnxInfos, ZnxLayout, alloc_aligned, assert_alignement};
+ use crate::{Backend, FFT64, Module, ZnxBase, ZnxBasics, ZnxInfos, ZnxLayout, alloc_aligned, assert_alignement};
use std::marker::PhantomData;
pub struct VecZnxBig<B: Backend> {
@@ -10,6 +10,9 @@ pub struct VecZnxBig<B: Backend> {
pub size: usize,
pub _marker: PhantomData<B>,
}
impl ZnxBasics for VecZnxBig<FFT64> {}
impl<B: Backend> ZnxBase<B> for VecZnxBig<B> {
type Scalar = u8;
@@ -112,265 +115,3 @@ impl VecZnxBig<FFT64> {
(0..self.size()).for_each(|i| println!("{}: {:?}", i, &self.at_limb(i)[..n]));
}
}
pub trait VecZnxBigOps<B: Backend> {
/// Allocates a vector Z[X]/(X^N+1) that stores not normalized values.
fn new_vec_znx_big(&self, cols: usize, size: usize) -> VecZnxBig<B>;
/// Returns a new [VecZnxBig] with the provided bytes array as backing array.
///
/// Behavior: takes ownership of the backing array.
///
/// # Arguments
///
/// * `cols`: the number of polynomials..
/// * `size`: the number of polynomials per column.
/// * `bytes`: a byte array of size at least [Module::bytes_of_vec_znx_big].
///
/// # Panics
/// If `bytes.len()` < [Module::bytes_of_vec_znx_big].
fn new_vec_znx_big_from_bytes(&self, cols: usize, size: usize, bytes: &mut [u8]) -> VecZnxBig<B>;
/// Returns a new [VecZnxBig] with the provided bytes array as backing array.
///
/// Behavior: the backing array is only borrowed.
///
/// # Arguments
///
/// * `cols`: the number of polynomials..
/// * `size`: the number of polynomials per column.
/// * `bytes`: a byte array of size at least [Module::bytes_of_vec_znx_big].
///
/// # Panics
/// If `bytes.len()` < [Module::bytes_of_vec_znx_big].
fn new_vec_znx_big_from_bytes_borrow(&self, cols: usize, size: usize, tmp_bytes: &mut [u8]) -> VecZnxBig<B>;
/// Returns the minimum number of bytes necessary to allocate
/// a new [VecZnxBig] through [VecZnxBig::from_bytes].
fn bytes_of_vec_znx_big(&self, cols: usize, size: usize) -> usize;
/// Subtracts `a` to `b` and stores the result on `b`.
fn vec_znx_big_sub_small_a_inplace(&self, b: &mut VecZnxBig<B>, a: &VecZnx);
/// Subtracts `b` to `a` and stores the result on `c`.
fn vec_znx_big_sub_small_a(&self, c: &mut VecZnxBig<B>, a: &VecZnx, b: &VecZnxBig<B>);
/// Adds `a` to `b` and stores the result on `c`.
fn vec_znx_big_add_small(&self, c: &mut VecZnxBig<B>, a: &VecZnx, b: &VecZnxBig<B>);
/// Adds `a` to `b` and stores the result on `b`.
fn vec_znx_big_add_small_inplace(&self, b: &mut VecZnxBig<B>, a: &VecZnx);
/// Returns the minimum number of bytes to apply [VecZnxBigOps::vec_znx_big_normalize].
fn vec_znx_big_normalize_tmp_bytes(&self) -> usize;
/// Normalizes `a` and stores the result on `b`.
///
/// # Arguments
///
/// * `log_base2k`: normalization basis.
/// * `tmp_bytes`: scratch space of size at least [VecZnxBigOps::vec_znx_big_normalize].
fn vec_znx_big_normalize(&self, log_base2k: usize, b: &mut VecZnx, a: &VecZnxBig<B>, tmp_bytes: &mut [u8]);
/// Returns the minimum number of bytes to apply [VecZnxBigOps::vec_znx_big_range_normalize_base2k].
fn vec_znx_big_range_normalize_base2k_tmp_bytes(&self) -> usize;
/// Normalize `a`, taking into account column interleaving and stores the result on `b`.
///
/// # Arguments
///
/// * `log_base2k`: normalization basis.
/// * `a_range_begin`: column to start.
/// * `a_range_end`: column to end.
/// * `a_range_step`: column step size.
/// * `tmp_bytes`: scratch space of size at least [VecZnxBigOps::vec_znx_big_range_normalize_base2k_tmp_bytes].
fn vec_znx_big_range_normalize_base2k(
&self,
log_base2k: usize,
b: &mut VecZnx,
a: &VecZnxBig<B>,
a_range_begin: usize,
a_range_xend: usize,
a_range_step: usize,
tmp_bytes: &mut [u8],
);
/// Applies the automorphism X^i -> X^ik on `a` and stores the result on `b`.
fn vec_znx_big_automorphism(&self, k: i64, b: &mut VecZnxBig<B>, a: &VecZnxBig<B>);
/// Applies the automorphism X^i -> X^ik on `a` and stores the result on `a`.
fn vec_znx_big_automorphism_inplace(&self, k: i64, a: &mut VecZnxBig<B>);
}
impl VecZnxBigOps<FFT64> for Module<FFT64> {
fn new_vec_znx_big(&self, cols: usize, size: usize) -> VecZnxBig<FFT64> {
VecZnxBig::new(self, cols, size)
}
fn new_vec_znx_big_from_bytes(&self, cols: usize, size: usize, bytes: &mut [u8]) -> VecZnxBig<FFT64> {
VecZnxBig::from_bytes(self, cols, size, bytes)
}
fn new_vec_znx_big_from_bytes_borrow(&self, cols: usize, size: usize, tmp_bytes: &mut [u8]) -> VecZnxBig<FFT64> {
VecZnxBig::from_bytes_borrow(self, cols, size, tmp_bytes)
}
fn bytes_of_vec_znx_big(&self, cols: usize, size: usize) -> usize {
VecZnxBig::bytes_of(self, cols, size)
}
fn vec_znx_big_sub_small_a_inplace(&self, b: &mut VecZnxBig<FFT64>, a: &VecZnx) {
unsafe {
vec_znx_big::vec_znx_big_sub_small_a(
self.ptr,
b.ptr as *mut vec_znx_big_t,
b.poly_count() as u64,
a.as_ptr(),
a.poly_count() as u64,
a.n() as u64,
b.ptr as *mut vec_znx_big_t,
b.poly_count() as u64,
)
}
}
fn vec_znx_big_sub_small_a(&self, c: &mut VecZnxBig<FFT64>, a: &VecZnx, b: &VecZnxBig<FFT64>) {
unsafe {
vec_znx_big::vec_znx_big_sub_small_a(
self.ptr,
c.ptr as *mut vec_znx_big_t,
c.poly_count() as u64,
a.as_ptr(),
a.poly_count() as u64,
a.n() as u64,
b.ptr as *mut vec_znx_big_t,
b.poly_count() as u64,
)
}
}
fn vec_znx_big_add_small(&self, c: &mut VecZnxBig<FFT64>, a: &VecZnx, b: &VecZnxBig<FFT64>) {
unsafe {
vec_znx_big::vec_znx_big_add_small(
self.ptr,
c.ptr as *mut vec_znx_big_t,
c.poly_count() as u64,
b.ptr as *mut vec_znx_big_t,
b.poly_count() as u64,
a.as_ptr(),
a.poly_count() as u64,
a.n() as u64,
)
}
}
fn vec_znx_big_add_small_inplace(&self, b: &mut VecZnxBig<FFT64>, a: &VecZnx) {
unsafe {
vec_znx_big::vec_znx_big_add_small(
self.ptr,
b.ptr as *mut vec_znx_big_t,
b.poly_count() as u64,
b.ptr as *mut vec_znx_big_t,
b.poly_count() as u64,
a.as_ptr(),
a.poly_count() as u64,
a.n() as u64,
)
}
}
fn vec_znx_big_normalize_tmp_bytes(&self) -> usize {
unsafe { vec_znx_big::vec_znx_big_normalize_base2k_tmp_bytes(self.ptr) as usize }
}
fn vec_znx_big_normalize(&self, log_base2k: usize, b: &mut VecZnx, a: &VecZnxBig<FFT64>, tmp_bytes: &mut [u8]) {
debug_assert!(
tmp_bytes.len() >= Self::vec_znx_big_normalize_tmp_bytes(self),
"invalid tmp_bytes: tmp_bytes.len()={} <= self.vec_znx_big_normalize_tmp_bytes()={}",
tmp_bytes.len(),
Self::vec_znx_big_normalize_tmp_bytes(self)
);
#[cfg(debug_assertions)]
{
assert_alignement(tmp_bytes.as_ptr())
}
unsafe {
vec_znx_big::vec_znx_big_normalize_base2k(
self.ptr,
log_base2k as u64,
b.as_mut_ptr(),
b.size() as u64,
b.n() as u64,
a.ptr as *mut vec_znx_big_t,
a.size() as u64,
tmp_bytes.as_mut_ptr(),
)
}
}
fn vec_znx_big_range_normalize_base2k_tmp_bytes(&self) -> usize {
unsafe { vec_znx_big::vec_znx_big_range_normalize_base2k_tmp_bytes(self.ptr) as usize }
}
fn vec_znx_big_range_normalize_base2k(
&self,
log_base2k: usize,
res: &mut VecZnx,
a: &VecZnxBig<FFT64>,
a_range_begin: usize,
a_range_xend: usize,
a_range_step: usize,
tmp_bytes: &mut [u8],
) {
debug_assert!(
tmp_bytes.len() >= Self::vec_znx_big_range_normalize_base2k_tmp_bytes(self),
"invalid tmp_bytes: tmp_bytes.len()={} <= self.vec_znx_big_range_normalize_base2k_tmp_bytes()={}",
tmp_bytes.len(),
Self::vec_znx_big_range_normalize_base2k_tmp_bytes(self)
);
#[cfg(debug_assertions)]
{
assert_alignement(tmp_bytes.as_ptr())
}
unsafe {
vec_znx_big::vec_znx_big_range_normalize_base2k(
self.ptr,
log_base2k as u64,
res.as_mut_ptr(),
res.size() as u64,
res.n() as u64,
a.ptr as *mut vec_znx_big_t,
a_range_begin as u64,
a_range_xend as u64,
a_range_step as u64,
tmp_bytes.as_mut_ptr(),
);
}
}
fn vec_znx_big_automorphism(&self, gal_el: i64, b: &mut VecZnxBig<FFT64>, a: &VecZnxBig<FFT64>) {
unsafe {
vec_znx_big::vec_znx_big_automorphism(
self.ptr,
gal_el,
b.ptr as *mut vec_znx_big_t,
b.poly_count() as u64,
a.ptr as *mut vec_znx_big_t,
a.poly_count() as u64,
);
}
}
fn vec_znx_big_automorphism_inplace(&self, gal_el: i64, a: &mut VecZnxBig<FFT64>) {
unsafe {
vec_znx_big::vec_znx_big_automorphism(
self.ptr,
gal_el,
a.ptr as *mut vec_znx_big_t,
a.poly_count() as u64,
a.ptr as *mut vec_znx_big_t,
a.poly_count() as u64,
);
}
}
}

View File

@@ -0,0 +1,339 @@
use crate::ffi::vec_znx_big::vec_znx_big_t;
use crate::ffi::{vec_znx, vec_znx_big};
use crate::internals::{apply_binary_op, ffi_ternary_op_factory};
use crate::{Backend, FFT64, Module, VecZnx, VecZnxBig, ZnxBase, ZnxInfos, ZnxLayout, assert_alignement};
pub trait VecZnxBigOps<B: Backend> {
/// Allocates a vector of Z[X]/(X^N+1) polynomials that stores non-normalized values.
fn new_vec_znx_big(&self, cols: usize, size: usize) -> VecZnxBig<B>;
/// Returns a new [VecZnxBig] with the provided bytes array as backing array.
///
/// Behavior: takes ownership of the backing array.
///
/// # Arguments
///
/// * `cols`: the number of columns.
/// * `size`: the number of polynomials per column.
/// * `bytes`: a byte array of size at least [Module::bytes_of_vec_znx_big].
///
/// # Panics
/// If `bytes.len()` < [Module::bytes_of_vec_znx_big].
fn new_vec_znx_big_from_bytes(&self, cols: usize, size: usize, bytes: &mut [u8]) -> VecZnxBig<B>;
/// Returns a new [VecZnxBig] with the provided bytes array as backing array.
///
/// Behavior: the backing array is only borrowed.
///
/// # Arguments
///
/// * `cols`: the number of columns.
/// * `size`: the number of polynomials per column.
/// * `bytes`: a byte array of size at least [Module::bytes_of_vec_znx_big].
///
/// # Panics
/// If `bytes.len()` < [Module::bytes_of_vec_znx_big].
fn new_vec_znx_big_from_bytes_borrow(&self, cols: usize, size: usize, tmp_bytes: &mut [u8]) -> VecZnxBig<B>;
/// Returns the minimum number of bytes necessary to allocate
/// a new [VecZnxBig] through [VecZnxBig::from_bytes].
fn bytes_of_vec_znx_big(&self, cols: usize, size: usize) -> usize;
/// Adds `a` to `b` and stores the result in `c` (`c <- a + b`).
fn vec_znx_big_add(&self, c: &mut VecZnxBig<B>, a: &VecZnxBig<B>, b: &VecZnxBig<B>);
/// Adds `a` to `b` and stores the result in `b` (`b <- a + b`).
fn vec_znx_big_add_inplace(&self, b: &mut VecZnxBig<B>, a: &VecZnxBig<B>);
/// Adds `a` to `b` and stores the result in `c` (`c <- a + b`).
fn vec_znx_big_add_small(&self, c: &mut VecZnxBig<B>, a: &VecZnx, b: &VecZnxBig<B>);
/// Adds `a` to `b` and stores the result in `b` (`b <- a + b`).
fn vec_znx_big_add_small_inplace(&self, b: &mut VecZnxBig<B>, a: &VecZnx);
/// Subtracts `b` from `a` and stores the result in `c` (`c <- a - b`).
fn vec_znx_big_sub(&self, c: &mut VecZnxBig<B>, a: &VecZnxBig<B>, b: &VecZnxBig<B>);
/// Subtracts `b` from `a` and stores the result in `b` (`b <- a - b`).
fn vec_znx_big_sub_ab_inplace(&self, b: &mut VecZnxBig<B>, a: &VecZnxBig<B>);
/// Subtracts `a` from `b` and stores the result in `b` (`b <- b - a`).
fn vec_znx_big_sub_ba_inplace(&self, b: &mut VecZnxBig<B>, a: &VecZnxBig<B>);
/// Subtracts `b` from `a` and stores the result in `c` (`c <- a - b`).
fn vec_znx_big_sub_small_ab(&self, c: &mut VecZnxBig<B>, a: &VecZnx, b: &VecZnxBig<B>);
/// Subtracts `b` from `a` and stores the result in `b` (`b <- a - b`).
fn vec_znx_big_sub_small_ab_inplace(&self, b: &mut VecZnxBig<B>, a: &VecZnx);
/// Subtracts `b` from `a` and stores the result in `c` (`c <- a - b`).
fn vec_znx_big_sub_small_ba(&self, c: &mut VecZnxBig<B>, a: &VecZnxBig<B>, b: &VecZnx);
/// Subtracts `a` from `b` and stores the result in `b` (`b <- b - a`).
fn vec_znx_big_sub_small_ba_inplace(&self, b: &mut VecZnxBig<B>, a: &VecZnx);
/// Returns the minimum number of bytes to apply [VecZnxBigOps::vec_znx_big_normalize].
fn vec_znx_big_normalize_tmp_bytes(&self) -> usize;
/// Normalizes `a` and stores the result on `b`.
///
/// # Arguments
///
/// * `log_base2k`: normalization basis.
/// * `tmp_bytes`: scratch space of size at least [VecZnxBigOps::vec_znx_big_normalize_tmp_bytes].
fn vec_znx_big_normalize(&self, log_base2k: usize, b: &mut VecZnx, a: &VecZnxBig<B>, tmp_bytes: &mut [u8]);
/// Returns the minimum number of bytes to apply [VecZnxBigOps::vec_znx_big_range_normalize_base2k].
fn vec_znx_big_range_normalize_base2k_tmp_bytes(&self) -> usize;
/// Normalizes `a`, taking into account column interleaving, and stores the result in `b`.
///
/// # Arguments
///
/// * `log_base2k`: normalization basis.
/// * `a_range_begin`: column to start.
/// * `a_range_xend`: column to end.
/// * `a_range_step`: column step size.
/// * `tmp_bytes`: scratch space of size at least [VecZnxBigOps::vec_znx_big_range_normalize_base2k_tmp_bytes].
fn vec_znx_big_range_normalize_base2k(
&self,
log_base2k: usize,
b: &mut VecZnx,
a: &VecZnxBig<B>,
a_range_begin: usize,
a_range_xend: usize,
a_range_step: usize,
tmp_bytes: &mut [u8],
);
/// Applies the automorphism X^i -> X^ik on `a` and stores the result on `b`.
fn vec_znx_big_automorphism(&self, k: i64, b: &mut VecZnxBig<B>, a: &VecZnxBig<B>);
/// Applies the automorphism X^i -> X^ik on `a` and stores the result on `a`.
fn vec_znx_big_automorphism_inplace(&self, k: i64, a: &mut VecZnxBig<B>);
}
impl VecZnxBigOps<FFT64> for Module<FFT64> {
fn new_vec_znx_big(&self, cols: usize, size: usize) -> VecZnxBig<FFT64> {
VecZnxBig::new(self, cols, size)
}
fn new_vec_znx_big_from_bytes(&self, cols: usize, size: usize, bytes: &mut [u8]) -> VecZnxBig<FFT64> {
VecZnxBig::from_bytes(self, cols, size, bytes)
}
fn new_vec_znx_big_from_bytes_borrow(&self, cols: usize, size: usize, tmp_bytes: &mut [u8]) -> VecZnxBig<FFT64> {
VecZnxBig::from_bytes_borrow(self, cols, size, tmp_bytes)
}
fn bytes_of_vec_znx_big(&self, cols: usize, size: usize) -> usize {
VecZnxBig::bytes_of(self, cols, size)
}
fn vec_znx_big_add(&self, c: &mut VecZnxBig<FFT64>, a: &VecZnxBig<FFT64>, b: &VecZnxBig<FFT64>) {
let op = ffi_ternary_op_factory(
self.ptr,
c.size(),
c.sl(),
a.size(),
a.sl(),
b.size(),
b.sl(),
vec_znx::vec_znx_add,
);
apply_binary_op::<FFT64, VecZnxBig<FFT64>, VecZnxBig<FFT64>, VecZnxBig<FFT64>, false>(self, c, a, b, op);
}
fn vec_znx_big_add_inplace(&self, b: &mut VecZnxBig<FFT64>, a: &VecZnxBig<FFT64>) {
unsafe {
let b_ptr: *mut VecZnxBig<FFT64> = b as *mut VecZnxBig<FFT64>;
Self::vec_znx_big_add(self, &mut *b_ptr, a, &*b_ptr);
}
}
fn vec_znx_big_sub(&self, c: &mut VecZnxBig<FFT64>, a: &VecZnxBig<FFT64>, b: &VecZnxBig<FFT64>) {
let op = ffi_ternary_op_factory(
self.ptr,
c.size(),
c.sl(),
a.size(),
a.sl(),
b.size(),
b.sl(),
vec_znx::vec_znx_sub,
);
apply_binary_op::<FFT64, VecZnxBig<FFT64>, VecZnxBig<FFT64>, VecZnxBig<FFT64>, true>(self, c, a, b, op);
}
fn vec_znx_big_sub_ab_inplace(&self, b: &mut VecZnxBig<FFT64>, a: &VecZnxBig<FFT64>) {
unsafe {
let b_ptr: *mut VecZnxBig<FFT64> = b as *mut VecZnxBig<FFT64>;
Self::vec_znx_big_sub(self, &mut *b_ptr, a, &*b_ptr);
}
}
fn vec_znx_big_sub_ba_inplace(&self, b: &mut VecZnxBig<FFT64>, a: &VecZnxBig<FFT64>) {
unsafe {
let b_ptr: *mut VecZnxBig<FFT64> = b as *mut VecZnxBig<FFT64>;
Self::vec_znx_big_sub(self, &mut *b_ptr, &*b_ptr, a);
}
}
fn vec_znx_big_sub_small_ba(&self, c: &mut VecZnxBig<FFT64>, a: &VecZnxBig<FFT64>, b: &VecZnx) {
let op = ffi_ternary_op_factory(
self.ptr,
c.size(),
c.sl(),
a.size(),
a.sl(),
b.size(),
b.sl(),
vec_znx::vec_znx_sub,
);
apply_binary_op::<FFT64, VecZnxBig<FFT64>, VecZnxBig<FFT64>, VecZnx, true>(self, c, a, b, op);
}
fn vec_znx_big_sub_small_ba_inplace(&self, b: &mut VecZnxBig<FFT64>, a: &VecZnx) {
unsafe {
let b_ptr: *mut VecZnxBig<FFT64> = b as *mut VecZnxBig<FFT64>;
Self::vec_znx_big_sub_small_ba(self, &mut *b_ptr, &*b_ptr, a);
}
}
fn vec_znx_big_sub_small_ab(&self, c: &mut VecZnxBig<FFT64>, a: &VecZnx, b: &VecZnxBig<FFT64>) {
let op = ffi_ternary_op_factory(
self.ptr,
c.size(),
c.sl(),
a.size(),
a.sl(),
b.size(),
b.sl(),
vec_znx::vec_znx_sub,
);
apply_binary_op::<FFT64, VecZnxBig<FFT64>, VecZnx, VecZnxBig<FFT64>, true>(self, c, a, b, op);
}
fn vec_znx_big_sub_small_ab_inplace(&self, b: &mut VecZnxBig<FFT64>, a: &VecZnx) {
unsafe {
let b_ptr: *mut VecZnxBig<FFT64> = b as *mut VecZnxBig<FFT64>;
Self::vec_znx_big_sub_small_ab(self, &mut *b_ptr, a, &*b_ptr);
}
}
fn vec_znx_big_add_small(&self, c: &mut VecZnxBig<FFT64>, a: &VecZnx, b: &VecZnxBig<FFT64>) {
let op = ffi_ternary_op_factory(
self.ptr,
c.size(),
c.sl(),
a.size(),
a.sl(),
b.size(),
b.sl(),
vec_znx::vec_znx_add,
);
apply_binary_op::<FFT64, VecZnxBig<FFT64>, VecZnx, VecZnxBig<FFT64>, false>(self, c, a, b, op);
}
fn vec_znx_big_add_small_inplace(&self, b: &mut VecZnxBig<FFT64>, a: &VecZnx) {
unsafe {
let b_ptr: *mut VecZnxBig<FFT64> = b as *mut VecZnxBig<FFT64>;
Self::vec_znx_big_add_small(self, &mut *b_ptr, a, &*b_ptr);
}
}
fn vec_znx_big_normalize_tmp_bytes(&self) -> usize {
unsafe { vec_znx_big::vec_znx_big_normalize_base2k_tmp_bytes(self.ptr) as usize }
}
fn vec_znx_big_normalize(&self, log_base2k: usize, b: &mut VecZnx, a: &VecZnxBig<FFT64>, tmp_bytes: &mut [u8]) {
debug_assert!(
tmp_bytes.len() >= Self::vec_znx_big_normalize_tmp_bytes(self),
"invalid tmp_bytes: tmp_bytes.len()={} <= self.vec_znx_big_normalize_tmp_bytes()={}",
tmp_bytes.len(),
Self::vec_znx_big_normalize_tmp_bytes(self)
);
#[cfg(debug_assertions)]
{
assert_alignement(tmp_bytes.as_ptr())
}
unsafe {
vec_znx_big::vec_znx_big_normalize_base2k(
self.ptr,
log_base2k as u64,
b.as_mut_ptr(),
b.size() as u64,
b.n() as u64,
a.ptr as *mut vec_znx_big_t,
a.size() as u64,
tmp_bytes.as_mut_ptr(),
)
}
}
fn vec_znx_big_range_normalize_base2k_tmp_bytes(&self) -> usize {
unsafe { vec_znx_big::vec_znx_big_range_normalize_base2k_tmp_bytes(self.ptr) as usize }
}
fn vec_znx_big_range_normalize_base2k(
&self,
log_base2k: usize,
res: &mut VecZnx,
a: &VecZnxBig<FFT64>,
a_range_begin: usize,
a_range_xend: usize,
a_range_step: usize,
tmp_bytes: &mut [u8],
) {
debug_assert!(
tmp_bytes.len() >= Self::vec_znx_big_range_normalize_base2k_tmp_bytes(self),
"invalid tmp_bytes: tmp_bytes.len()={} <= self.vec_znx_big_range_normalize_base2k_tmp_bytes()={}",
tmp_bytes.len(),
Self::vec_znx_big_range_normalize_base2k_tmp_bytes(self)
);
#[cfg(debug_assertions)]
{
assert_alignement(tmp_bytes.as_ptr())
}
unsafe {
vec_znx_big::vec_znx_big_range_normalize_base2k(
self.ptr,
log_base2k as u64,
res.as_mut_ptr(),
res.size() as u64,
res.n() as u64,
a.ptr as *mut vec_znx_big_t,
a_range_begin as u64,
a_range_xend as u64,
a_range_step as u64,
tmp_bytes.as_mut_ptr(),
);
}
}
fn vec_znx_big_automorphism(&self, gal_el: i64, b: &mut VecZnxBig<FFT64>, a: &VecZnxBig<FFT64>) {
unsafe {
vec_znx_big::vec_znx_big_automorphism(
self.ptr,
gal_el,
b.ptr as *mut vec_znx_big_t,
b.poly_count() as u64,
a.ptr as *mut vec_znx_big_t,
a.poly_count() as u64,
);
}
}
fn vec_znx_big_automorphism_inplace(&self, gal_el: i64, a: &mut VecZnxBig<FFT64>) {
unsafe {
vec_znx_big::vec_znx_big_automorphism(
self.ptr,
gal_el,
a.ptr as *mut vec_znx_big_t,
a.poly_count() as u64,
a.ptr as *mut vec_znx_big_t,
a.poly_count() as u64,
);
}
}
}
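An end-to-end sketch of the new trait, mirroring the example in the first hunk: it assumes a caller that already holds a `Module<FFT64>`, a message `m: VecZnx`, an unnormalized accumulator `acc: VecZnxBig<FFT64>` (e.g. produced by earlier big-domain operations), an output `res: VecZnx`, and a scratch buffer that is large enough and aligned (the debug assertions check `tmp_bytes` alignment, so in practice it should come from an aligned allocator such as the crate's `alloc_aligned`):

use crate::{FFT64, Module, VecZnx, VecZnxBig, VecZnxBigOps};

fn update_and_normalize(
    module: &Module<FFT64>,
    log_base2k: usize,
    m: &VecZnx,
    acc: &mut VecZnxBig<FFT64>,
    res: &mut VecZnx,
    tmp_bytes: &mut [u8], // >= module.vec_znx_big_normalize_tmp_bytes(), aligned
) {
    // acc <- m - acc (the renamed in-place subtraction from the example hunk)
    module.vec_znx_big_sub_small_ab_inplace(acc, m);
    // acc <- m + acc
    module.vec_znx_big_add_small_inplace(acc, m);
    // res <- acc, renormalized into base-2^k limbs
    module.vec_znx_big_normalize(log_base2k, res, acc, tmp_bytes);
}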

View File

@@ -1,7 +1,7 @@
use crate::ffi::module::MODULE;
use crate::ffi::vec_znx;
- use crate::{apply_binary_op, apply_unary_op, switch_degree, znx_post_process_ternary_op, Backend, Module, VecZnx, ZnxBase, ZnxBasics, ZnxInfos, ZnxLayout};
+ use crate::internals::{apply_binary_op, apply_unary_op, ffi_binary_op_factory_type_0, ffi_binary_op_factory_type_1};
- use std::cmp::min;
+ use crate::{Backend, Module, VecZnx, ZnxBase, ZnxInfos, switch_degree};
pub trait VecZnxOps {
/// Allocates a new [VecZnx].
///
@@ -125,7 +125,7 @@ impl<B: Backend> VecZnxOps for Module<B> {
b.sl(),
vec_znx::vec_znx_add,
);
- apply_binary_op::<B, VecZnx, false>(self, c, a, b, op);
+ apply_binary_op::<B, VecZnx, VecZnx, VecZnx, false>(self, c, a, b, op);
}
fn vec_znx_add_inplace(&self, b: &mut VecZnx, a: &VecZnx) {
@@ -146,7 +146,7 @@ impl<B: Backend> VecZnxOps for Module<B> {
b.sl(),
vec_znx::vec_znx_sub,
);
- apply_binary_op::<B, VecZnx, true>(self, c, a, b, op);
+ apply_binary_op::<B, VecZnx, VecZnx, VecZnx, true>(self, c, a, b, op);
}
fn vec_znx_sub_ab_inplace(&self, b: &mut VecZnx, a: &VecZnx) {
@@ -298,56 +298,11 @@ fn ffi_ternary_op_factory(
}
}
fn ffi_binary_op_factory_type_0(
module_ptr: *const MODULE,
b_size: usize,
b_sl: usize,
a_size: usize,
a_sl: usize,
op_fn: unsafe extern "C" fn(*const MODULE, *mut i64, u64, u64, *const i64, u64, u64),
) -> impl Fn(&mut [i64], &[i64]) {
move |bv: &mut [i64], av: &[i64]| unsafe {
op_fn(
module_ptr,
bv.as_mut_ptr(),
b_size as u64,
b_sl as u64,
av.as_ptr(),
a_size as u64,
a_sl as u64,
)
}
}
fn ffi_binary_op_factory_type_1(
module_ptr: *const MODULE,
k: i64,
b_size: usize,
b_sl: usize,
a_size: usize,
a_sl: usize,
op_fn: unsafe extern "C" fn(*const MODULE, i64, *mut i64, u64, u64, *const i64, u64, u64),
) -> impl Fn(&mut [i64], &[i64]) {
move |bv: &mut [i64], av: &[i64]| unsafe {
op_fn(
module_ptr,
k,
bv.as_mut_ptr(),
b_size as u64,
b_sl as u64,
av.as_ptr(),
a_size as u64,
a_sl as u64,
)
}
}
#[cfg(test)]
mod tests {
- use crate::{
-     Backend, FFT64, Module, Sampling, VecZnx, VecZnxOps, ZnxBasics, ZnxInfos, ZnxLayout, ffi::vec_znx,
-     znx_post_process_ternary_op,
- };
+ use crate::internals::znx_post_process_ternary_op;
+ use crate::{Backend, FFT64, Module, Sampling, VecZnx, VecZnxOps, ZnxBasics, ZnxInfos, ZnxLayout, ffi::vec_znx};
use itertools::izip;
use sampling::source::Source;
use std::cmp::min;
@@ -623,7 +578,7 @@ mod tests {
}
});
- znx_post_process_ternary_op::<_, NEGATE>(&mut c_want, &a, &b);
+ znx_post_process_ternary_op::<VecZnx, VecZnx, VecZnx, NEGATE>(&mut c_want, &a, &b);
assert_eq!(c_have.raw(), c_want.raw());
}); });