mirror of
https://github.com/arnaucube/poulpy.git
synced 2026-02-10 13:16:44 +01:00
Added VecZnxBig<FFT64> ops
This commit is contained in:
@@ -59,7 +59,7 @@ fn main() {
|
|||||||
m.normalize(log_base2k, &mut carry);
|
m.normalize(log_base2k, &mut carry);
|
||||||
|
|
||||||
// buf_big <- m - buf_big
|
// buf_big <- m - buf_big
|
||||||
module.vec_znx_big_sub_small_a_inplace(&mut buf_big, &m);
|
module.vec_znx_big_sub_small_ab_inplace(&mut buf_big, &m);
|
||||||
|
|
||||||
println!("{:?}", buf_big.raw());
|
println!("{:?}", buf_big.raw());
|
||||||
|
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
use crate::{Backend, Module, assert_alignement, cast_mut};
|
use crate::{Backend, Module, assert_alignement, cast_mut};
|
||||||
use itertools::izip;
|
use itertools::izip;
|
||||||
use std::cmp::{max, min};
|
use std::cmp::min;
|
||||||
|
|
||||||
pub trait ZnxInfos {
|
pub trait ZnxInfos {
|
||||||
/// Returns the ring degree of the polynomials.
|
/// Returns the ring degree of the polynomials.
|
||||||
@@ -243,105 +243,3 @@ where
|
|||||||
.for_each(|(x_in, x_out)| *x_out = *x_in);
|
.for_each(|(x_in, x_out)| *x_out = *x_in);
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn znx_post_process_ternary_op<T: ZnxLayout + ZnxBasics, const NEGATE: bool>(c: &mut T, a: &T, b: &T)
|
|
||||||
where
|
|
||||||
<T as ZnxLayout>::Scalar: IntegerType,
|
|
||||||
{
|
|
||||||
#[cfg(debug_assertions)]
|
|
||||||
{
|
|
||||||
assert_ne!(a.as_ptr(), b.as_ptr());
|
|
||||||
assert_ne!(b.as_ptr(), c.as_ptr());
|
|
||||||
assert_ne!(a.as_ptr(), c.as_ptr());
|
|
||||||
}
|
|
||||||
|
|
||||||
let a_cols: usize = a.cols();
|
|
||||||
let b_cols: usize = b.cols();
|
|
||||||
let c_cols: usize = c.cols();
|
|
||||||
|
|
||||||
let min_ab_cols: usize = min(a_cols, b_cols);
|
|
||||||
let max_ab_cols: usize = max(a_cols, b_cols);
|
|
||||||
|
|
||||||
// Copies shared shared cols between (c, max(a, b))
|
|
||||||
if a_cols != b_cols {
|
|
||||||
let mut x: &T = a;
|
|
||||||
if a_cols < b_cols {
|
|
||||||
x = b;
|
|
||||||
}
|
|
||||||
|
|
||||||
let min_size = min(c.size(), x.size());
|
|
||||||
(min_ab_cols..min(max_ab_cols, c_cols)).for_each(|i| {
|
|
||||||
(0..min_size).for_each(|j| {
|
|
||||||
c.at_poly_mut(i, j).copy_from_slice(x.at_poly(i, j));
|
|
||||||
if NEGATE {
|
|
||||||
c.at_poly_mut(i, j).iter_mut().for_each(|x| *x = -*x);
|
|
||||||
}
|
|
||||||
});
|
|
||||||
(min_size..c.size()).for_each(|j| {
|
|
||||||
c.zero_at(i, j);
|
|
||||||
});
|
|
||||||
});
|
|
||||||
}
|
|
||||||
|
|
||||||
// Zeroes the cols of c > max(a, b).
|
|
||||||
if c_cols > max_ab_cols {
|
|
||||||
(max_ab_cols..c_cols).for_each(|i| {
|
|
||||||
(0..c.size()).for_each(|j| {
|
|
||||||
c.zero_at(i, j);
|
|
||||||
})
|
|
||||||
});
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[inline(always)]
|
|
||||||
pub fn apply_binary_op<B: Backend, T: ZnxBasics + ZnxLayout, const NEGATE: bool>(
|
|
||||||
module: &Module<B>,
|
|
||||||
c: &mut T,
|
|
||||||
a: &T,
|
|
||||||
b: &T,
|
|
||||||
op: impl Fn(&mut [T::Scalar], &[T::Scalar], &[T::Scalar]),
|
|
||||||
) where
|
|
||||||
<T as ZnxLayout>::Scalar: IntegerType,
|
|
||||||
{
|
|
||||||
#[cfg(debug_assertions)]
|
|
||||||
{
|
|
||||||
assert_eq!(a.n(), module.n());
|
|
||||||
assert_eq!(b.n(), module.n());
|
|
||||||
assert_eq!(c.n(), module.n());
|
|
||||||
assert_ne!(a.as_ptr(), b.as_ptr());
|
|
||||||
}
|
|
||||||
let a_cols: usize = a.cols();
|
|
||||||
let b_cols: usize = b.cols();
|
|
||||||
let c_cols: usize = c.cols();
|
|
||||||
let min_ab_cols: usize = min(a_cols, b_cols);
|
|
||||||
let min_cols: usize = min(c_cols, min_ab_cols);
|
|
||||||
// Applies over shared cols between (a, b, c)
|
|
||||||
(0..min_cols).for_each(|i| op(c.at_poly_mut(i, 0), a.at_poly(i, 0), b.at_poly(i, 0)));
|
|
||||||
// Copies/Negates/Zeroes the remaining cols if op is not inplace.
|
|
||||||
if c.as_ptr() != a.as_ptr() && c.as_ptr() != b.as_ptr() {
|
|
||||||
znx_post_process_ternary_op::<T, NEGATE>(c, a, b);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[inline(always)]
|
|
||||||
pub fn apply_unary_op<B: Backend, T: ZnxBasics + ZnxLayout>(
|
|
||||||
module: &Module<B>,
|
|
||||||
b: &mut T,
|
|
||||||
a: &T,
|
|
||||||
op: impl Fn(&mut [T::Scalar], &[T::Scalar]),
|
|
||||||
) where
|
|
||||||
<T as ZnxLayout>::Scalar: IntegerType,
|
|
||||||
{
|
|
||||||
#[cfg(debug_assertions)]
|
|
||||||
{
|
|
||||||
assert_eq!(a.n(), module.n());
|
|
||||||
assert_eq!(b.n(), module.n());
|
|
||||||
}
|
|
||||||
let a_cols: usize = a.cols();
|
|
||||||
let b_cols: usize = b.cols();
|
|
||||||
let min_cols: usize = min(a_cols, b_cols);
|
|
||||||
// Applies over the shared cols between (a, b)
|
|
||||||
(0..min_cols).for_each(|i| op(b.at_poly_mut(i, 0), a.at_poly(i, 0)));
|
|
||||||
// Zeroes the remaining cols of b.
|
|
||||||
(min_cols..b_cols).for_each(|i| (0..b.size()).for_each(|j| b.zero_at(i, j)));
|
|
||||||
}
|
|
||||||
|
|||||||
192
base2k/src/internals.rs
Normal file
192
base2k/src/internals.rs
Normal file
@@ -0,0 +1,192 @@
|
|||||||
|
use std::cmp::{max, min};
|
||||||
|
|
||||||
|
use crate::{Backend, IntegerType, Module, ZnxBasics, ZnxLayout, ffi::module::MODULE};
|
||||||
|
|
||||||
|
pub(crate) fn znx_post_process_ternary_op<C, A, B, const NEGATE: bool>(c: &mut C, a: &A, b: &B)
|
||||||
|
where
|
||||||
|
C: ZnxBasics + ZnxLayout,
|
||||||
|
A: ZnxBasics + ZnxLayout<Scalar = C::Scalar>,
|
||||||
|
B: ZnxBasics + ZnxLayout<Scalar = C::Scalar>,
|
||||||
|
C::Scalar: IntegerType,
|
||||||
|
{
|
||||||
|
#[cfg(debug_assertions)]
|
||||||
|
{
|
||||||
|
assert_ne!(a.as_ptr(), b.as_ptr());
|
||||||
|
assert_ne!(b.as_ptr(), c.as_ptr());
|
||||||
|
assert_ne!(a.as_ptr(), c.as_ptr());
|
||||||
|
}
|
||||||
|
|
||||||
|
let a_cols: usize = a.cols();
|
||||||
|
let b_cols: usize = b.cols();
|
||||||
|
let c_cols: usize = c.cols();
|
||||||
|
|
||||||
|
let min_ab_cols: usize = min(a_cols, b_cols);
|
||||||
|
let max_ab_cols: usize = max(a_cols, b_cols);
|
||||||
|
|
||||||
|
// Copies shared shared cols between (c, max(a, b))
|
||||||
|
if a_cols != b_cols {
|
||||||
|
if a_cols > b_cols {
|
||||||
|
let min_size = min(c.size(), a.size());
|
||||||
|
(min_ab_cols..min(max_ab_cols, c_cols)).for_each(|i| {
|
||||||
|
(0..min_size).for_each(|j| {
|
||||||
|
c.at_poly_mut(i, j).copy_from_slice(a.at_poly(i, j));
|
||||||
|
if NEGATE {
|
||||||
|
c.at_poly_mut(i, j).iter_mut().for_each(|x| *x = -*x);
|
||||||
|
}
|
||||||
|
});
|
||||||
|
(min_size..c.size()).for_each(|j| {
|
||||||
|
c.zero_at(i, j);
|
||||||
|
});
|
||||||
|
});
|
||||||
|
} else {
|
||||||
|
let min_size = min(c.size(), b.size());
|
||||||
|
(min_ab_cols..min(max_ab_cols, c_cols)).for_each(|i| {
|
||||||
|
(0..min_size).for_each(|j| {
|
||||||
|
c.at_poly_mut(i, j).copy_from_slice(b.at_poly(i, j));
|
||||||
|
if NEGATE {
|
||||||
|
c.at_poly_mut(i, j).iter_mut().for_each(|x| *x = -*x);
|
||||||
|
}
|
||||||
|
});
|
||||||
|
(min_size..c.size()).for_each(|j| {
|
||||||
|
c.zero_at(i, j);
|
||||||
|
});
|
||||||
|
});
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Zeroes the cols of c > max(a, b).
|
||||||
|
if c_cols > max_ab_cols {
|
||||||
|
(max_ab_cols..c_cols).for_each(|i| {
|
||||||
|
(0..c.size()).for_each(|j| {
|
||||||
|
c.zero_at(i, j);
|
||||||
|
})
|
||||||
|
});
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[inline(always)]
|
||||||
|
pub fn apply_binary_op<BE, C, A, B, const NEGATE: bool>(
|
||||||
|
module: &Module<BE>,
|
||||||
|
c: &mut C,
|
||||||
|
a: &A,
|
||||||
|
b: &B,
|
||||||
|
op: impl Fn(&mut [C::Scalar], &[A::Scalar], &[B::Scalar]),
|
||||||
|
) where
|
||||||
|
BE: Backend,
|
||||||
|
C: ZnxBasics + ZnxLayout,
|
||||||
|
A: ZnxBasics + ZnxLayout<Scalar = C::Scalar>,
|
||||||
|
B: ZnxBasics + ZnxLayout<Scalar = C::Scalar>,
|
||||||
|
C::Scalar: IntegerType,
|
||||||
|
{
|
||||||
|
#[cfg(debug_assertions)]
|
||||||
|
{
|
||||||
|
assert_eq!(a.n(), module.n());
|
||||||
|
assert_eq!(b.n(), module.n());
|
||||||
|
assert_eq!(c.n(), module.n());
|
||||||
|
assert_ne!(a.as_ptr(), b.as_ptr());
|
||||||
|
}
|
||||||
|
let a_cols: usize = a.cols();
|
||||||
|
let b_cols: usize = b.cols();
|
||||||
|
let c_cols: usize = c.cols();
|
||||||
|
let min_ab_cols: usize = min(a_cols, b_cols);
|
||||||
|
let min_cols: usize = min(c_cols, min_ab_cols);
|
||||||
|
// Applies over shared cols between (a, b, c)
|
||||||
|
(0..min_cols).for_each(|i| op(c.at_poly_mut(i, 0), a.at_poly(i, 0), b.at_poly(i, 0)));
|
||||||
|
// Copies/Negates/Zeroes the remaining cols if op is not inplace.
|
||||||
|
if c.as_ptr() != a.as_ptr() && c.as_ptr() != b.as_ptr() {
|
||||||
|
znx_post_process_ternary_op::<C, A, B, NEGATE>(c, a, b);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[inline(always)]
|
||||||
|
pub fn apply_unary_op<B: Backend, T: ZnxBasics + ZnxLayout>(
|
||||||
|
module: &Module<B>,
|
||||||
|
b: &mut T,
|
||||||
|
a: &T,
|
||||||
|
op: impl Fn(&mut [T::Scalar], &[T::Scalar]),
|
||||||
|
) where
|
||||||
|
<T as ZnxLayout>::Scalar: IntegerType,
|
||||||
|
{
|
||||||
|
#[cfg(debug_assertions)]
|
||||||
|
{
|
||||||
|
assert_eq!(a.n(), module.n());
|
||||||
|
assert_eq!(b.n(), module.n());
|
||||||
|
}
|
||||||
|
let a_cols: usize = a.cols();
|
||||||
|
let b_cols: usize = b.cols();
|
||||||
|
let min_cols: usize = min(a_cols, b_cols);
|
||||||
|
// Applies over the shared cols between (a, b)
|
||||||
|
(0..min_cols).for_each(|i| op(b.at_poly_mut(i, 0), a.at_poly(i, 0)));
|
||||||
|
// Zeroes the remaining cols of b.
|
||||||
|
(min_cols..b_cols).for_each(|i| (0..b.size()).for_each(|j| b.zero_at(i, j)));
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn ffi_ternary_op_factory<T>(
|
||||||
|
module_ptr: *const MODULE,
|
||||||
|
c_size: usize,
|
||||||
|
c_sl: usize,
|
||||||
|
a_size: usize,
|
||||||
|
a_sl: usize,
|
||||||
|
b_size: usize,
|
||||||
|
b_sl: usize,
|
||||||
|
op_fn: unsafe extern "C" fn(*const MODULE, *mut T, u64, u64, *const T, u64, u64, *const T, u64, u64),
|
||||||
|
) -> impl Fn(&mut [T], &[T], &[T]) {
|
||||||
|
move |cv: &mut [T], av: &[T], bv: &[T]| unsafe {
|
||||||
|
op_fn(
|
||||||
|
module_ptr,
|
||||||
|
cv.as_mut_ptr(),
|
||||||
|
c_size as u64,
|
||||||
|
c_sl as u64,
|
||||||
|
av.as_ptr(),
|
||||||
|
a_size as u64,
|
||||||
|
a_sl as u64,
|
||||||
|
bv.as_ptr(),
|
||||||
|
b_size as u64,
|
||||||
|
b_sl as u64,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn ffi_binary_op_factory_type_0<T>(
|
||||||
|
module_ptr: *const MODULE,
|
||||||
|
b_size: usize,
|
||||||
|
b_sl: usize,
|
||||||
|
a_size: usize,
|
||||||
|
a_sl: usize,
|
||||||
|
op_fn: unsafe extern "C" fn(*const MODULE, *mut T, u64, u64, *const T, u64, u64),
|
||||||
|
) -> impl Fn(&mut [T], &[T]) {
|
||||||
|
move |bv: &mut [T], av: &[T]| unsafe {
|
||||||
|
op_fn(
|
||||||
|
module_ptr,
|
||||||
|
bv.as_mut_ptr(),
|
||||||
|
b_size as u64,
|
||||||
|
b_sl as u64,
|
||||||
|
av.as_ptr(),
|
||||||
|
a_size as u64,
|
||||||
|
a_sl as u64,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn ffi_binary_op_factory_type_1<T>(
|
||||||
|
module_ptr: *const MODULE,
|
||||||
|
k: i64,
|
||||||
|
b_size: usize,
|
||||||
|
b_sl: usize,
|
||||||
|
a_size: usize,
|
||||||
|
a_sl: usize,
|
||||||
|
op_fn: unsafe extern "C" fn(*const MODULE, i64, *mut T, u64, u64, *const T, u64, u64),
|
||||||
|
) -> impl Fn(&mut [T], &[T]) {
|
||||||
|
move |bv: &mut [T], av: &[T]| unsafe {
|
||||||
|
op_fn(
|
||||||
|
module_ptr,
|
||||||
|
k,
|
||||||
|
bv.as_mut_ptr(),
|
||||||
|
b_size as u64,
|
||||||
|
b_sl as u64,
|
||||||
|
av.as_ptr(),
|
||||||
|
a_size as u64,
|
||||||
|
a_sl as u64,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -3,6 +3,7 @@ pub mod encoding;
|
|||||||
#[allow(non_camel_case_types, non_snake_case, non_upper_case_globals, dead_code, improper_ctypes)]
|
#[allow(non_camel_case_types, non_snake_case, non_upper_case_globals, dead_code, improper_ctypes)]
|
||||||
// Other modules and exports
|
// Other modules and exports
|
||||||
pub mod ffi;
|
pub mod ffi;
|
||||||
|
mod internals;
|
||||||
pub mod mat_znx_dft;
|
pub mod mat_znx_dft;
|
||||||
pub mod module;
|
pub mod module;
|
||||||
pub mod sampling;
|
pub mod sampling;
|
||||||
@@ -10,6 +11,7 @@ pub mod scalar_znx_dft;
|
|||||||
pub mod stats;
|
pub mod stats;
|
||||||
pub mod vec_znx;
|
pub mod vec_znx;
|
||||||
pub mod vec_znx_big;
|
pub mod vec_znx_big;
|
||||||
|
pub mod vec_znx_big_ops;
|
||||||
pub mod vec_znx_dft;
|
pub mod vec_znx_dft;
|
||||||
pub mod vec_znx_ops;
|
pub mod vec_znx_ops;
|
||||||
|
|
||||||
@@ -23,6 +25,7 @@ pub use scalar_znx_dft::*;
|
|||||||
pub use stats::*;
|
pub use stats::*;
|
||||||
pub use vec_znx::*;
|
pub use vec_znx::*;
|
||||||
pub use vec_znx_big::*;
|
pub use vec_znx_big::*;
|
||||||
|
pub use vec_znx_big_ops::*;
|
||||||
pub use vec_znx_dft::*;
|
pub use vec_znx_dft::*;
|
||||||
pub use vec_znx_ops::*;
|
pub use vec_znx_ops::*;
|
||||||
|
|
||||||
|
|||||||
@@ -1,5 +1,5 @@
|
|||||||
use crate::ffi::vec_znx_big::{self, vec_znx_big_t};
|
use crate::ffi::vec_znx_big;
|
||||||
use crate::{Backend, FFT64, Module, VecZnx, ZnxBase, ZnxInfos, ZnxLayout, alloc_aligned, assert_alignement};
|
use crate::{Backend, FFT64, Module, ZnxBase, ZnxBasics, ZnxInfos, ZnxLayout, alloc_aligned, assert_alignement};
|
||||||
use std::marker::PhantomData;
|
use std::marker::PhantomData;
|
||||||
|
|
||||||
pub struct VecZnxBig<B: Backend> {
|
pub struct VecZnxBig<B: Backend> {
|
||||||
@@ -10,6 +10,9 @@ pub struct VecZnxBig<B: Backend> {
|
|||||||
pub size: usize,
|
pub size: usize,
|
||||||
pub _marker: PhantomData<B>,
|
pub _marker: PhantomData<B>,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
impl ZnxBasics for VecZnxBig<FFT64> {}
|
||||||
|
|
||||||
impl<B: Backend> ZnxBase<B> for VecZnxBig<B> {
|
impl<B: Backend> ZnxBase<B> for VecZnxBig<B> {
|
||||||
type Scalar = u8;
|
type Scalar = u8;
|
||||||
|
|
||||||
@@ -112,265 +115,3 @@ impl VecZnxBig<FFT64> {
|
|||||||
(0..self.size()).for_each(|i| println!("{}: {:?}", i, &self.at_limb(i)[..n]));
|
(0..self.size()).for_each(|i| println!("{}: {:?}", i, &self.at_limb(i)[..n]));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub trait VecZnxBigOps<B: Backend> {
|
|
||||||
/// Allocates a vector Z[X]/(X^N+1) that stores not normalized values.
|
|
||||||
fn new_vec_znx_big(&self, cols: usize, size: usize) -> VecZnxBig<B>;
|
|
||||||
|
|
||||||
/// Returns a new [VecZnxBig] with the provided bytes array as backing array.
|
|
||||||
///
|
|
||||||
/// Behavior: takes ownership of the backing array.
|
|
||||||
///
|
|
||||||
/// # Arguments
|
|
||||||
///
|
|
||||||
/// * `cols`: the number of polynomials..
|
|
||||||
/// * `size`: the number of polynomials per column.
|
|
||||||
/// * `bytes`: a byte array of size at least [Module::bytes_of_vec_znx_big].
|
|
||||||
///
|
|
||||||
/// # Panics
|
|
||||||
/// If `bytes.len()` < [Module::bytes_of_vec_znx_big].
|
|
||||||
fn new_vec_znx_big_from_bytes(&self, cols: usize, size: usize, bytes: &mut [u8]) -> VecZnxBig<B>;
|
|
||||||
|
|
||||||
/// Returns a new [VecZnxBig] with the provided bytes array as backing array.
|
|
||||||
///
|
|
||||||
/// Behavior: the backing array is only borrowed.
|
|
||||||
///
|
|
||||||
/// # Arguments
|
|
||||||
///
|
|
||||||
/// * `cols`: the number of polynomials..
|
|
||||||
/// * `size`: the number of polynomials per column.
|
|
||||||
/// * `bytes`: a byte array of size at least [Module::bytes_of_vec_znx_big].
|
|
||||||
///
|
|
||||||
/// # Panics
|
|
||||||
/// If `bytes.len()` < [Module::bytes_of_vec_znx_big].
|
|
||||||
fn new_vec_znx_big_from_bytes_borrow(&self, cols: usize, size: usize, tmp_bytes: &mut [u8]) -> VecZnxBig<B>;
|
|
||||||
|
|
||||||
/// Returns the minimum number of bytes necessary to allocate
|
|
||||||
/// a new [VecZnxBig] through [VecZnxBig::from_bytes].
|
|
||||||
fn bytes_of_vec_znx_big(&self, cols: usize, size: usize) -> usize;
|
|
||||||
|
|
||||||
/// Subtracts `a` to `b` and stores the result on `b`.
|
|
||||||
fn vec_znx_big_sub_small_a_inplace(&self, b: &mut VecZnxBig<B>, a: &VecZnx);
|
|
||||||
|
|
||||||
/// Subtracts `b` to `a` and stores the result on `c`.
|
|
||||||
fn vec_znx_big_sub_small_a(&self, c: &mut VecZnxBig<B>, a: &VecZnx, b: &VecZnxBig<B>);
|
|
||||||
|
|
||||||
/// Adds `a` to `b` and stores the result on `c`.
|
|
||||||
fn vec_znx_big_add_small(&self, c: &mut VecZnxBig<B>, a: &VecZnx, b: &VecZnxBig<B>);
|
|
||||||
|
|
||||||
/// Adds `a` to `b` and stores the result on `b`.
|
|
||||||
fn vec_znx_big_add_small_inplace(&self, b: &mut VecZnxBig<B>, a: &VecZnx);
|
|
||||||
|
|
||||||
/// Returns the minimum number of bytes to apply [VecZnxBigOps::vec_znx_big_normalize].
|
|
||||||
fn vec_znx_big_normalize_tmp_bytes(&self) -> usize;
|
|
||||||
|
|
||||||
/// Normalizes `a` and stores the result on `b`.
|
|
||||||
///
|
|
||||||
/// # Arguments
|
|
||||||
///
|
|
||||||
/// * `log_base2k`: normalization basis.
|
|
||||||
/// * `tmp_bytes`: scratch space of size at least [VecZnxBigOps::vec_znx_big_normalize].
|
|
||||||
fn vec_znx_big_normalize(&self, log_base2k: usize, b: &mut VecZnx, a: &VecZnxBig<B>, tmp_bytes: &mut [u8]);
|
|
||||||
|
|
||||||
/// Returns the minimum number of bytes to apply [VecZnxBigOps::vec_znx_big_range_normalize_base2k].
|
|
||||||
fn vec_znx_big_range_normalize_base2k_tmp_bytes(&self) -> usize;
|
|
||||||
|
|
||||||
/// Normalize `a`, taking into account column interleaving and stores the result on `b`.
|
|
||||||
///
|
|
||||||
/// # Arguments
|
|
||||||
///
|
|
||||||
/// * `log_base2k`: normalization basis.
|
|
||||||
/// * `a_range_begin`: column to start.
|
|
||||||
/// * `a_range_end`: column to end.
|
|
||||||
/// * `a_range_step`: column step size.
|
|
||||||
/// * `tmp_bytes`: scratch space of size at least [VecZnxBigOps::vec_znx_big_range_normalize_base2k_tmp_bytes].
|
|
||||||
fn vec_znx_big_range_normalize_base2k(
|
|
||||||
&self,
|
|
||||||
log_base2k: usize,
|
|
||||||
b: &mut VecZnx,
|
|
||||||
a: &VecZnxBig<B>,
|
|
||||||
a_range_begin: usize,
|
|
||||||
a_range_xend: usize,
|
|
||||||
a_range_step: usize,
|
|
||||||
tmp_bytes: &mut [u8],
|
|
||||||
);
|
|
||||||
|
|
||||||
/// Applies the automorphism X^i -> X^ik on `a` and stores the result on `b`.
|
|
||||||
fn vec_znx_big_automorphism(&self, k: i64, b: &mut VecZnxBig<B>, a: &VecZnxBig<B>);
|
|
||||||
|
|
||||||
/// Applies the automorphism X^i -> X^ik on `a` and stores the result on `a`.
|
|
||||||
fn vec_znx_big_automorphism_inplace(&self, k: i64, a: &mut VecZnxBig<B>);
|
|
||||||
}
|
|
||||||
|
|
||||||
impl VecZnxBigOps<FFT64> for Module<FFT64> {
|
|
||||||
fn new_vec_znx_big(&self, cols: usize, size: usize) -> VecZnxBig<FFT64> {
|
|
||||||
VecZnxBig::new(self, cols, size)
|
|
||||||
}
|
|
||||||
|
|
||||||
fn new_vec_znx_big_from_bytes(&self, cols: usize, size: usize, bytes: &mut [u8]) -> VecZnxBig<FFT64> {
|
|
||||||
VecZnxBig::from_bytes(self, cols, size, bytes)
|
|
||||||
}
|
|
||||||
|
|
||||||
fn new_vec_znx_big_from_bytes_borrow(&self, cols: usize, size: usize, tmp_bytes: &mut [u8]) -> VecZnxBig<FFT64> {
|
|
||||||
VecZnxBig::from_bytes_borrow(self, cols, size, tmp_bytes)
|
|
||||||
}
|
|
||||||
|
|
||||||
fn bytes_of_vec_znx_big(&self, cols: usize, size: usize) -> usize {
|
|
||||||
VecZnxBig::bytes_of(self, cols, size)
|
|
||||||
}
|
|
||||||
|
|
||||||
fn vec_znx_big_sub_small_a_inplace(&self, b: &mut VecZnxBig<FFT64>, a: &VecZnx) {
|
|
||||||
unsafe {
|
|
||||||
vec_znx_big::vec_znx_big_sub_small_a(
|
|
||||||
self.ptr,
|
|
||||||
b.ptr as *mut vec_znx_big_t,
|
|
||||||
b.poly_count() as u64,
|
|
||||||
a.as_ptr(),
|
|
||||||
a.poly_count() as u64,
|
|
||||||
a.n() as u64,
|
|
||||||
b.ptr as *mut vec_znx_big_t,
|
|
||||||
b.poly_count() as u64,
|
|
||||||
)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
fn vec_znx_big_sub_small_a(&self, c: &mut VecZnxBig<FFT64>, a: &VecZnx, b: &VecZnxBig<FFT64>) {
|
|
||||||
unsafe {
|
|
||||||
vec_znx_big::vec_znx_big_sub_small_a(
|
|
||||||
self.ptr,
|
|
||||||
c.ptr as *mut vec_znx_big_t,
|
|
||||||
c.poly_count() as u64,
|
|
||||||
a.as_ptr(),
|
|
||||||
a.poly_count() as u64,
|
|
||||||
a.n() as u64,
|
|
||||||
b.ptr as *mut vec_znx_big_t,
|
|
||||||
b.poly_count() as u64,
|
|
||||||
)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
fn vec_znx_big_add_small(&self, c: &mut VecZnxBig<FFT64>, a: &VecZnx, b: &VecZnxBig<FFT64>) {
|
|
||||||
unsafe {
|
|
||||||
vec_znx_big::vec_znx_big_add_small(
|
|
||||||
self.ptr,
|
|
||||||
c.ptr as *mut vec_znx_big_t,
|
|
||||||
c.poly_count() as u64,
|
|
||||||
b.ptr as *mut vec_znx_big_t,
|
|
||||||
b.poly_count() as u64,
|
|
||||||
a.as_ptr(),
|
|
||||||
a.poly_count() as u64,
|
|
||||||
a.n() as u64,
|
|
||||||
)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
fn vec_znx_big_add_small_inplace(&self, b: &mut VecZnxBig<FFT64>, a: &VecZnx) {
|
|
||||||
unsafe {
|
|
||||||
vec_znx_big::vec_znx_big_add_small(
|
|
||||||
self.ptr,
|
|
||||||
b.ptr as *mut vec_znx_big_t,
|
|
||||||
b.poly_count() as u64,
|
|
||||||
b.ptr as *mut vec_znx_big_t,
|
|
||||||
b.poly_count() as u64,
|
|
||||||
a.as_ptr(),
|
|
||||||
a.poly_count() as u64,
|
|
||||||
a.n() as u64,
|
|
||||||
)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
fn vec_znx_big_normalize_tmp_bytes(&self) -> usize {
|
|
||||||
unsafe { vec_znx_big::vec_znx_big_normalize_base2k_tmp_bytes(self.ptr) as usize }
|
|
||||||
}
|
|
||||||
|
|
||||||
fn vec_znx_big_normalize(&self, log_base2k: usize, b: &mut VecZnx, a: &VecZnxBig<FFT64>, tmp_bytes: &mut [u8]) {
|
|
||||||
debug_assert!(
|
|
||||||
tmp_bytes.len() >= Self::vec_znx_big_normalize_tmp_bytes(self),
|
|
||||||
"invalid tmp_bytes: tmp_bytes.len()={} <= self.vec_znx_big_normalize_tmp_bytes()={}",
|
|
||||||
tmp_bytes.len(),
|
|
||||||
Self::vec_znx_big_normalize_tmp_bytes(self)
|
|
||||||
);
|
|
||||||
#[cfg(debug_assertions)]
|
|
||||||
{
|
|
||||||
assert_alignement(tmp_bytes.as_ptr())
|
|
||||||
}
|
|
||||||
unsafe {
|
|
||||||
vec_znx_big::vec_znx_big_normalize_base2k(
|
|
||||||
self.ptr,
|
|
||||||
log_base2k as u64,
|
|
||||||
b.as_mut_ptr(),
|
|
||||||
b.size() as u64,
|
|
||||||
b.n() as u64,
|
|
||||||
a.ptr as *mut vec_znx_big_t,
|
|
||||||
a.size() as u64,
|
|
||||||
tmp_bytes.as_mut_ptr(),
|
|
||||||
)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
fn vec_znx_big_range_normalize_base2k_tmp_bytes(&self) -> usize {
|
|
||||||
unsafe { vec_znx_big::vec_znx_big_range_normalize_base2k_tmp_bytes(self.ptr) as usize }
|
|
||||||
}
|
|
||||||
|
|
||||||
fn vec_znx_big_range_normalize_base2k(
|
|
||||||
&self,
|
|
||||||
log_base2k: usize,
|
|
||||||
res: &mut VecZnx,
|
|
||||||
a: &VecZnxBig<FFT64>,
|
|
||||||
a_range_begin: usize,
|
|
||||||
a_range_xend: usize,
|
|
||||||
a_range_step: usize,
|
|
||||||
tmp_bytes: &mut [u8],
|
|
||||||
) {
|
|
||||||
debug_assert!(
|
|
||||||
tmp_bytes.len() >= Self::vec_znx_big_range_normalize_base2k_tmp_bytes(self),
|
|
||||||
"invalid tmp_bytes: tmp_bytes.len()={} <= self.vec_znx_big_range_normalize_base2k_tmp_bytes()={}",
|
|
||||||
tmp_bytes.len(),
|
|
||||||
Self::vec_znx_big_range_normalize_base2k_tmp_bytes(self)
|
|
||||||
);
|
|
||||||
#[cfg(debug_assertions)]
|
|
||||||
{
|
|
||||||
assert_alignement(tmp_bytes.as_ptr())
|
|
||||||
}
|
|
||||||
unsafe {
|
|
||||||
vec_znx_big::vec_znx_big_range_normalize_base2k(
|
|
||||||
self.ptr,
|
|
||||||
log_base2k as u64,
|
|
||||||
res.as_mut_ptr(),
|
|
||||||
res.size() as u64,
|
|
||||||
res.n() as u64,
|
|
||||||
a.ptr as *mut vec_znx_big_t,
|
|
||||||
a_range_begin as u64,
|
|
||||||
a_range_xend as u64,
|
|
||||||
a_range_step as u64,
|
|
||||||
tmp_bytes.as_mut_ptr(),
|
|
||||||
);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
fn vec_znx_big_automorphism(&self, gal_el: i64, b: &mut VecZnxBig<FFT64>, a: &VecZnxBig<FFT64>) {
|
|
||||||
unsafe {
|
|
||||||
vec_znx_big::vec_znx_big_automorphism(
|
|
||||||
self.ptr,
|
|
||||||
gal_el,
|
|
||||||
b.ptr as *mut vec_znx_big_t,
|
|
||||||
b.poly_count() as u64,
|
|
||||||
a.ptr as *mut vec_znx_big_t,
|
|
||||||
a.poly_count() as u64,
|
|
||||||
);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
fn vec_znx_big_automorphism_inplace(&self, gal_el: i64, a: &mut VecZnxBig<FFT64>) {
|
|
||||||
unsafe {
|
|
||||||
vec_znx_big::vec_znx_big_automorphism(
|
|
||||||
self.ptr,
|
|
||||||
gal_el,
|
|
||||||
a.ptr as *mut vec_znx_big_t,
|
|
||||||
a.poly_count() as u64,
|
|
||||||
a.ptr as *mut vec_znx_big_t,
|
|
||||||
a.poly_count() as u64,
|
|
||||||
);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|||||||
339
base2k/src/vec_znx_big_ops.rs
Normal file
339
base2k/src/vec_znx_big_ops.rs
Normal file
@@ -0,0 +1,339 @@
|
|||||||
|
use crate::ffi::vec_znx_big::vec_znx_big_t;
|
||||||
|
use crate::ffi::{vec_znx, vec_znx_big};
|
||||||
|
use crate::internals::{apply_binary_op, ffi_ternary_op_factory};
|
||||||
|
use crate::{Backend, FFT64, Module, VecZnx, VecZnxBig, ZnxBase, ZnxInfos, ZnxLayout, assert_alignement};
|
||||||
|
|
||||||
|
pub trait VecZnxBigOps<B: Backend> {
|
||||||
|
/// Allocates a vector Z[X]/(X^N+1) that stores not normalized values.
|
||||||
|
fn new_vec_znx_big(&self, cols: usize, size: usize) -> VecZnxBig<B>;
|
||||||
|
|
||||||
|
/// Returns a new [VecZnxBig] with the provided bytes array as backing array.
|
||||||
|
///
|
||||||
|
/// Behavior: takes ownership of the backing array.
|
||||||
|
///
|
||||||
|
/// # Arguments
|
||||||
|
///
|
||||||
|
/// * `cols`: the number of polynomials..
|
||||||
|
/// * `size`: the number of polynomials per column.
|
||||||
|
/// * `bytes`: a byte array of size at least [Module::bytes_of_vec_znx_big].
|
||||||
|
///
|
||||||
|
/// # Panics
|
||||||
|
/// If `bytes.len()` < [Module::bytes_of_vec_znx_big].
|
||||||
|
fn new_vec_znx_big_from_bytes(&self, cols: usize, size: usize, bytes: &mut [u8]) -> VecZnxBig<B>;
|
||||||
|
|
||||||
|
/// Returns a new [VecZnxBig] with the provided bytes array as backing array.
|
||||||
|
///
|
||||||
|
/// Behavior: the backing array is only borrowed.
|
||||||
|
///
|
||||||
|
/// # Arguments
|
||||||
|
///
|
||||||
|
/// * `cols`: the number of polynomials..
|
||||||
|
/// * `size`: the number of polynomials per column.
|
||||||
|
/// * `bytes`: a byte array of size at least [Module::bytes_of_vec_znx_big].
|
||||||
|
///
|
||||||
|
/// # Panics
|
||||||
|
/// If `bytes.len()` < [Module::bytes_of_vec_znx_big].
|
||||||
|
fn new_vec_znx_big_from_bytes_borrow(&self, cols: usize, size: usize, tmp_bytes: &mut [u8]) -> VecZnxBig<B>;
|
||||||
|
|
||||||
|
/// Returns the minimum number of bytes necessary to allocate
|
||||||
|
/// a new [VecZnxBig] through [VecZnxBig::from_bytes].
|
||||||
|
fn bytes_of_vec_znx_big(&self, cols: usize, size: usize) -> usize;
|
||||||
|
|
||||||
|
/// Adds `a` to `b` and stores the result on `c`.
|
||||||
|
fn vec_znx_big_add(&self, c: &mut VecZnxBig<B>, a: &VecZnxBig<B>, b: &VecZnxBig<B>);
|
||||||
|
|
||||||
|
/// Adds `a` to `b` and stores the result on `b`.
|
||||||
|
fn vec_znx_big_add_inplace(&self, b: &mut VecZnxBig<B>, a: &VecZnxBig<B>);
|
||||||
|
|
||||||
|
/// Adds `a` to `b` and stores the result on `c`.
|
||||||
|
fn vec_znx_big_add_small(&self, c: &mut VecZnxBig<B>, a: &VecZnx, b: &VecZnxBig<B>);
|
||||||
|
|
||||||
|
/// Adds `a` to `b` and stores the result on `b`.
|
||||||
|
fn vec_znx_big_add_small_inplace(&self, b: &mut VecZnxBig<B>, a: &VecZnx);
|
||||||
|
|
||||||
|
/// Subtracts `a` to `b` and stores the result on `c`.
|
||||||
|
fn vec_znx_big_sub(&self, c: &mut VecZnxBig<B>, a: &VecZnxBig<B>, b: &VecZnxBig<B>);
|
||||||
|
|
||||||
|
/// Subtracts `a` to `b` and stores the result on `b`.
|
||||||
|
fn vec_znx_big_sub_ab_inplace(&self, b: &mut VecZnxBig<B>, a: &VecZnxBig<B>);
|
||||||
|
|
||||||
|
/// Subtracts `b` to `a` and stores the result on `b`.
|
||||||
|
fn vec_znx_big_sub_ba_inplace(&self, b: &mut VecZnxBig<B>, a: &VecZnxBig<B>);
|
||||||
|
|
||||||
|
/// Subtracts `b` to `a` and stores the result on `c`.
|
||||||
|
fn vec_znx_big_sub_small_ab(&self, c: &mut VecZnxBig<B>, a: &VecZnx, b: &VecZnxBig<B>);
|
||||||
|
|
||||||
|
/// Subtracts `a` to `b` and stores the result on `b`.
|
||||||
|
fn vec_znx_big_sub_small_ab_inplace(&self, b: &mut VecZnxBig<B>, a: &VecZnx);
|
||||||
|
|
||||||
|
/// Subtracts `b` to `a` and stores the result on `c`.
|
||||||
|
fn vec_znx_big_sub_small_ba(&self, c: &mut VecZnxBig<B>, a: &VecZnxBig<B>, b: &VecZnx);
|
||||||
|
|
||||||
|
/// Subtracts `b` to `a` and stores the result on `b`.
|
||||||
|
fn vec_znx_big_sub_small_ba_inplace(&self, b: &mut VecZnxBig<B>, a: &VecZnx);
|
||||||
|
|
||||||
|
/// Returns the minimum number of bytes to apply [VecZnxBigOps::vec_znx_big_normalize].
|
||||||
|
fn vec_znx_big_normalize_tmp_bytes(&self) -> usize;
|
||||||
|
|
||||||
|
/// Normalizes `a` and stores the result on `b`.
|
||||||
|
///
|
||||||
|
/// # Arguments
|
||||||
|
///
|
||||||
|
/// * `log_base2k`: normalization basis.
|
||||||
|
/// * `tmp_bytes`: scratch space of size at least [VecZnxBigOps::vec_znx_big_normalize].
|
||||||
|
fn vec_znx_big_normalize(&self, log_base2k: usize, b: &mut VecZnx, a: &VecZnxBig<B>, tmp_bytes: &mut [u8]);
|
||||||
|
|
||||||
|
/// Returns the minimum number of bytes to apply [VecZnxBigOps::vec_znx_big_range_normalize_base2k].
|
||||||
|
fn vec_znx_big_range_normalize_base2k_tmp_bytes(&self) -> usize;
|
||||||
|
|
||||||
|
/// Normalize `a`, taking into account column interleaving and stores the result on `b`.
|
||||||
|
///
|
||||||
|
/// # Arguments
|
||||||
|
///
|
||||||
|
/// * `log_base2k`: normalization basis.
|
||||||
|
/// * `a_range_begin`: column to start.
|
||||||
|
/// * `a_range_end`: column to end.
|
||||||
|
/// * `a_range_step`: column step size.
|
||||||
|
/// * `tmp_bytes`: scratch space of size at least [VecZnxBigOps::vec_znx_big_range_normalize_base2k_tmp_bytes].
|
||||||
|
fn vec_znx_big_range_normalize_base2k(
|
||||||
|
&self,
|
||||||
|
log_base2k: usize,
|
||||||
|
b: &mut VecZnx,
|
||||||
|
a: &VecZnxBig<B>,
|
||||||
|
a_range_begin: usize,
|
||||||
|
a_range_xend: usize,
|
||||||
|
a_range_step: usize,
|
||||||
|
tmp_bytes: &mut [u8],
|
||||||
|
);
|
||||||
|
|
||||||
|
/// Applies the automorphism X^i -> X^ik on `a` and stores the result on `b`.
|
||||||
|
fn vec_znx_big_automorphism(&self, k: i64, b: &mut VecZnxBig<B>, a: &VecZnxBig<B>);
|
||||||
|
|
||||||
|
/// Applies the automorphism X^i -> X^ik on `a` and stores the result on `a`.
|
||||||
|
fn vec_znx_big_automorphism_inplace(&self, k: i64, a: &mut VecZnxBig<B>);
|
||||||
|
}
|
||||||
|
|
||||||
|
impl VecZnxBigOps<FFT64> for Module<FFT64> {
|
||||||
|
fn new_vec_znx_big(&self, cols: usize, size: usize) -> VecZnxBig<FFT64> {
|
||||||
|
VecZnxBig::new(self, cols, size)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn new_vec_znx_big_from_bytes(&self, cols: usize, size: usize, bytes: &mut [u8]) -> VecZnxBig<FFT64> {
|
||||||
|
VecZnxBig::from_bytes(self, cols, size, bytes)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn new_vec_znx_big_from_bytes_borrow(&self, cols: usize, size: usize, tmp_bytes: &mut [u8]) -> VecZnxBig<FFT64> {
|
||||||
|
VecZnxBig::from_bytes_borrow(self, cols, size, tmp_bytes)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn bytes_of_vec_znx_big(&self, cols: usize, size: usize) -> usize {
|
||||||
|
VecZnxBig::bytes_of(self, cols, size)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn vec_znx_big_add(&self, c: &mut VecZnxBig<FFT64>, a: &VecZnxBig<FFT64>, b: &VecZnxBig<FFT64>) {
|
||||||
|
let op = ffi_ternary_op_factory(
|
||||||
|
self.ptr,
|
||||||
|
c.size(),
|
||||||
|
c.sl(),
|
||||||
|
a.size(),
|
||||||
|
a.sl(),
|
||||||
|
b.size(),
|
||||||
|
b.sl(),
|
||||||
|
vec_znx::vec_znx_add,
|
||||||
|
);
|
||||||
|
apply_binary_op::<FFT64, VecZnxBig<FFT64>, VecZnxBig<FFT64>, VecZnxBig<FFT64>, false>(self, c, a, b, op);
|
||||||
|
}
|
||||||
|
|
||||||
|
fn vec_znx_big_add_inplace(&self, b: &mut VecZnxBig<FFT64>, a: &VecZnxBig<FFT64>) {
|
||||||
|
unsafe {
|
||||||
|
let b_ptr: *mut VecZnxBig<FFT64> = b as *mut VecZnxBig<FFT64>;
|
||||||
|
Self::vec_znx_big_add(self, &mut *b_ptr, a, &*b_ptr);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn vec_znx_big_sub(&self, c: &mut VecZnxBig<FFT64>, a: &VecZnxBig<FFT64>, b: &VecZnxBig<FFT64>) {
|
||||||
|
let op = ffi_ternary_op_factory(
|
||||||
|
self.ptr,
|
||||||
|
c.size(),
|
||||||
|
c.sl(),
|
||||||
|
a.size(),
|
||||||
|
a.sl(),
|
||||||
|
b.size(),
|
||||||
|
b.sl(),
|
||||||
|
vec_znx::vec_znx_sub,
|
||||||
|
);
|
||||||
|
apply_binary_op::<FFT64, VecZnxBig<FFT64>, VecZnxBig<FFT64>, VecZnxBig<FFT64>, true>(self, c, a, b, op);
|
||||||
|
}
|
||||||
|
|
||||||
|
fn vec_znx_big_sub_ab_inplace(&self, b: &mut VecZnxBig<FFT64>, a: &VecZnxBig<FFT64>) {
|
||||||
|
unsafe {
|
||||||
|
let b_ptr: *mut VecZnxBig<FFT64> = b as *mut VecZnxBig<FFT64>;
|
||||||
|
Self::vec_znx_big_sub(self, &mut *b_ptr, a, &*b_ptr);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn vec_znx_big_sub_ba_inplace(&self, b: &mut VecZnxBig<FFT64>, a: &VecZnxBig<FFT64>) {
|
||||||
|
unsafe {
|
||||||
|
let b_ptr: *mut VecZnxBig<FFT64> = b as *mut VecZnxBig<FFT64>;
|
||||||
|
Self::vec_znx_big_sub(self, &mut *b_ptr, &*b_ptr, a);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn vec_znx_big_sub_small_ba(&self, c: &mut VecZnxBig<FFT64>, a: &VecZnxBig<FFT64>, b: &VecZnx) {
|
||||||
|
let op = ffi_ternary_op_factory(
|
||||||
|
self.ptr,
|
||||||
|
c.size(),
|
||||||
|
c.sl(),
|
||||||
|
a.size(),
|
||||||
|
a.sl(),
|
||||||
|
b.size(),
|
||||||
|
b.sl(),
|
||||||
|
vec_znx::vec_znx_sub,
|
||||||
|
);
|
||||||
|
apply_binary_op::<FFT64, VecZnxBig<FFT64>, VecZnxBig<FFT64>, VecZnx, true>(self, c, a, b, op);
|
||||||
|
}
|
||||||
|
|
||||||
|
fn vec_znx_big_sub_small_ba_inplace(&self, b: &mut VecZnxBig<FFT64>, a: &VecZnx) {
|
||||||
|
unsafe {
|
||||||
|
let b_ptr: *mut VecZnxBig<FFT64> = b as *mut VecZnxBig<FFT64>;
|
||||||
|
Self::vec_znx_big_sub_small_ba(self, &mut *b_ptr, &*b_ptr, a);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn vec_znx_big_sub_small_ab(&self, c: &mut VecZnxBig<FFT64>, a: &VecZnx, b: &VecZnxBig<FFT64>) {
|
||||||
|
let op = ffi_ternary_op_factory(
|
||||||
|
self.ptr,
|
||||||
|
c.size(),
|
||||||
|
c.sl(),
|
||||||
|
a.size(),
|
||||||
|
a.sl(),
|
||||||
|
b.size(),
|
||||||
|
b.sl(),
|
||||||
|
vec_znx::vec_znx_sub,
|
||||||
|
);
|
||||||
|
apply_binary_op::<FFT64, VecZnxBig<FFT64>, VecZnx, VecZnxBig<FFT64>, true>(self, c, a, b, op);
|
||||||
|
}
|
||||||
|
|
||||||
|
fn vec_znx_big_sub_small_ab_inplace(&self, b: &mut VecZnxBig<FFT64>, a: &VecZnx) {
|
||||||
|
unsafe {
|
||||||
|
let b_ptr: *mut VecZnxBig<FFT64> = b as *mut VecZnxBig<FFT64>;
|
||||||
|
Self::vec_znx_big_sub_small_ab(self, &mut *b_ptr, a, &*b_ptr);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn vec_znx_big_add_small(&self, c: &mut VecZnxBig<FFT64>, a: &VecZnx, b: &VecZnxBig<FFT64>) {
|
||||||
|
let op = ffi_ternary_op_factory(
|
||||||
|
self.ptr,
|
||||||
|
c.size(),
|
||||||
|
c.sl(),
|
||||||
|
a.size(),
|
||||||
|
a.sl(),
|
||||||
|
b.size(),
|
||||||
|
b.sl(),
|
||||||
|
vec_znx::vec_znx_add,
|
||||||
|
);
|
||||||
|
apply_binary_op::<FFT64, VecZnxBig<FFT64>, VecZnx, VecZnxBig<FFT64>, false>(self, c, a, b, op);
|
||||||
|
}
|
||||||
|
|
||||||
|
fn vec_znx_big_add_small_inplace(&self, b: &mut VecZnxBig<FFT64>, a: &VecZnx) {
|
||||||
|
unsafe {
|
||||||
|
let b_ptr: *mut VecZnxBig<FFT64> = b as *mut VecZnxBig<FFT64>;
|
||||||
|
Self::vec_znx_big_add_small(self, &mut *b_ptr, a, &*b_ptr);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn vec_znx_big_normalize_tmp_bytes(&self) -> usize {
|
||||||
|
unsafe { vec_znx_big::vec_znx_big_normalize_base2k_tmp_bytes(self.ptr) as usize }
|
||||||
|
}
|
||||||
|
|
||||||
|
fn vec_znx_big_normalize(&self, log_base2k: usize, b: &mut VecZnx, a: &VecZnxBig<FFT64>, tmp_bytes: &mut [u8]) {
|
||||||
|
debug_assert!(
|
||||||
|
tmp_bytes.len() >= Self::vec_znx_big_normalize_tmp_bytes(self),
|
||||||
|
"invalid tmp_bytes: tmp_bytes.len()={} <= self.vec_znx_big_normalize_tmp_bytes()={}",
|
||||||
|
tmp_bytes.len(),
|
||||||
|
Self::vec_znx_big_normalize_tmp_bytes(self)
|
||||||
|
);
|
||||||
|
#[cfg(debug_assertions)]
|
||||||
|
{
|
||||||
|
assert_alignement(tmp_bytes.as_ptr())
|
||||||
|
}
|
||||||
|
unsafe {
|
||||||
|
vec_znx_big::vec_znx_big_normalize_base2k(
|
||||||
|
self.ptr,
|
||||||
|
log_base2k as u64,
|
||||||
|
b.as_mut_ptr(),
|
||||||
|
b.size() as u64,
|
||||||
|
b.n() as u64,
|
||||||
|
a.ptr as *mut vec_znx_big_t,
|
||||||
|
a.size() as u64,
|
||||||
|
tmp_bytes.as_mut_ptr(),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn vec_znx_big_range_normalize_base2k_tmp_bytes(&self) -> usize {
|
||||||
|
unsafe { vec_znx_big::vec_znx_big_range_normalize_base2k_tmp_bytes(self.ptr) as usize }
|
||||||
|
}
|
||||||
|
|
||||||
|
fn vec_znx_big_range_normalize_base2k(
|
||||||
|
&self,
|
||||||
|
log_base2k: usize,
|
||||||
|
res: &mut VecZnx,
|
||||||
|
a: &VecZnxBig<FFT64>,
|
||||||
|
a_range_begin: usize,
|
||||||
|
a_range_xend: usize,
|
||||||
|
a_range_step: usize,
|
||||||
|
tmp_bytes: &mut [u8],
|
||||||
|
) {
|
||||||
|
debug_assert!(
|
||||||
|
tmp_bytes.len() >= Self::vec_znx_big_range_normalize_base2k_tmp_bytes(self),
|
||||||
|
"invalid tmp_bytes: tmp_bytes.len()={} <= self.vec_znx_big_range_normalize_base2k_tmp_bytes()={}",
|
||||||
|
tmp_bytes.len(),
|
||||||
|
Self::vec_znx_big_range_normalize_base2k_tmp_bytes(self)
|
||||||
|
);
|
||||||
|
#[cfg(debug_assertions)]
|
||||||
|
{
|
||||||
|
assert_alignement(tmp_bytes.as_ptr())
|
||||||
|
}
|
||||||
|
unsafe {
|
||||||
|
vec_znx_big::vec_znx_big_range_normalize_base2k(
|
||||||
|
self.ptr,
|
||||||
|
log_base2k as u64,
|
||||||
|
res.as_mut_ptr(),
|
||||||
|
res.size() as u64,
|
||||||
|
res.n() as u64,
|
||||||
|
a.ptr as *mut vec_znx_big_t,
|
||||||
|
a_range_begin as u64,
|
||||||
|
a_range_xend as u64,
|
||||||
|
a_range_step as u64,
|
||||||
|
tmp_bytes.as_mut_ptr(),
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn vec_znx_big_automorphism(&self, gal_el: i64, b: &mut VecZnxBig<FFT64>, a: &VecZnxBig<FFT64>) {
|
||||||
|
unsafe {
|
||||||
|
vec_znx_big::vec_znx_big_automorphism(
|
||||||
|
self.ptr,
|
||||||
|
gal_el,
|
||||||
|
b.ptr as *mut vec_znx_big_t,
|
||||||
|
b.poly_count() as u64,
|
||||||
|
a.ptr as *mut vec_znx_big_t,
|
||||||
|
a.poly_count() as u64,
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn vec_znx_big_automorphism_inplace(&self, gal_el: i64, a: &mut VecZnxBig<FFT64>) {
|
||||||
|
unsafe {
|
||||||
|
vec_znx_big::vec_znx_big_automorphism(
|
||||||
|
self.ptr,
|
||||||
|
gal_el,
|
||||||
|
a.ptr as *mut vec_znx_big_t,
|
||||||
|
a.poly_count() as u64,
|
||||||
|
a.ptr as *mut vec_znx_big_t,
|
||||||
|
a.poly_count() as u64,
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -1,7 +1,7 @@
|
|||||||
use crate::ffi::module::MODULE;
|
use crate::ffi::module::MODULE;
|
||||||
use crate::ffi::vec_znx;
|
use crate::ffi::vec_znx;
|
||||||
use crate::{apply_binary_op, apply_unary_op, switch_degree, znx_post_process_ternary_op, Backend, Module, VecZnx, ZnxBase, ZnxBasics, ZnxInfos, ZnxLayout};
|
use crate::internals::{apply_binary_op, apply_unary_op, ffi_binary_op_factory_type_0, ffi_binary_op_factory_type_1};
|
||||||
use std::cmp::min;
|
use crate::{Backend, Module, VecZnx, ZnxBase, ZnxInfos, switch_degree};
|
||||||
pub trait VecZnxOps {
|
pub trait VecZnxOps {
|
||||||
/// Allocates a new [VecZnx].
|
/// Allocates a new [VecZnx].
|
||||||
///
|
///
|
||||||
@@ -125,7 +125,7 @@ impl<B: Backend> VecZnxOps for Module<B> {
|
|||||||
b.sl(),
|
b.sl(),
|
||||||
vec_znx::vec_znx_add,
|
vec_znx::vec_znx_add,
|
||||||
);
|
);
|
||||||
apply_binary_op::<B, VecZnx, false>(self, c, a, b, op);
|
apply_binary_op::<B, VecZnx, VecZnx, VecZnx, false>(self, c, a, b, op);
|
||||||
}
|
}
|
||||||
|
|
||||||
fn vec_znx_add_inplace(&self, b: &mut VecZnx, a: &VecZnx) {
|
fn vec_znx_add_inplace(&self, b: &mut VecZnx, a: &VecZnx) {
|
||||||
@@ -146,7 +146,7 @@ impl<B: Backend> VecZnxOps for Module<B> {
|
|||||||
b.sl(),
|
b.sl(),
|
||||||
vec_znx::vec_znx_sub,
|
vec_znx::vec_znx_sub,
|
||||||
);
|
);
|
||||||
apply_binary_op::<B, VecZnx, true>(self, c, a, b, op);
|
apply_binary_op::<B, VecZnx, VecZnx, VecZnx, true>(self, c, a, b, op);
|
||||||
}
|
}
|
||||||
|
|
||||||
fn vec_znx_sub_ab_inplace(&self, b: &mut VecZnx, a: &VecZnx) {
|
fn vec_znx_sub_ab_inplace(&self, b: &mut VecZnx, a: &VecZnx) {
|
||||||
@@ -298,56 +298,11 @@ fn ffi_ternary_op_factory(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn ffi_binary_op_factory_type_0(
|
|
||||||
module_ptr: *const MODULE,
|
|
||||||
b_size: usize,
|
|
||||||
b_sl: usize,
|
|
||||||
a_size: usize,
|
|
||||||
a_sl: usize,
|
|
||||||
op_fn: unsafe extern "C" fn(*const MODULE, *mut i64, u64, u64, *const i64, u64, u64),
|
|
||||||
) -> impl Fn(&mut [i64], &[i64]) {
|
|
||||||
move |bv: &mut [i64], av: &[i64]| unsafe {
|
|
||||||
op_fn(
|
|
||||||
module_ptr,
|
|
||||||
bv.as_mut_ptr(),
|
|
||||||
b_size as u64,
|
|
||||||
b_sl as u64,
|
|
||||||
av.as_ptr(),
|
|
||||||
a_size as u64,
|
|
||||||
a_sl as u64,
|
|
||||||
)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
fn ffi_binary_op_factory_type_1(
|
|
||||||
module_ptr: *const MODULE,
|
|
||||||
k: i64,
|
|
||||||
b_size: usize,
|
|
||||||
b_sl: usize,
|
|
||||||
a_size: usize,
|
|
||||||
a_sl: usize,
|
|
||||||
op_fn: unsafe extern "C" fn(*const MODULE, i64, *mut i64, u64, u64, *const i64, u64, u64),
|
|
||||||
) -> impl Fn(&mut [i64], &[i64]) {
|
|
||||||
move |bv: &mut [i64], av: &[i64]| unsafe {
|
|
||||||
op_fn(
|
|
||||||
module_ptr,
|
|
||||||
k,
|
|
||||||
bv.as_mut_ptr(),
|
|
||||||
b_size as u64,
|
|
||||||
b_sl as u64,
|
|
||||||
av.as_ptr(),
|
|
||||||
a_size as u64,
|
|
||||||
a_sl as u64,
|
|
||||||
)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
mod tests {
|
mod tests {
|
||||||
use crate::{
|
use crate::internals::znx_post_process_ternary_op;
|
||||||
Backend, FFT64, Module, Sampling, VecZnx, VecZnxOps, ZnxBasics, ZnxInfos, ZnxLayout, ffi::vec_znx,
|
use crate::{Backend, FFT64, Module, Sampling, VecZnx, VecZnxOps, ZnxBasics, ZnxInfos, ZnxLayout, ffi::vec_znx};
|
||||||
znx_post_process_ternary_op,
|
|
||||||
};
|
|
||||||
use itertools::izip;
|
use itertools::izip;
|
||||||
use sampling::source::Source;
|
use sampling::source::Source;
|
||||||
use std::cmp::min;
|
use std::cmp::min;
|
||||||
@@ -623,7 +578,7 @@ mod tests {
|
|||||||
}
|
}
|
||||||
});
|
});
|
||||||
|
|
||||||
znx_post_process_ternary_op::<_, NEGATE>(&mut c_want, &a, &b);
|
znx_post_process_ternary_op::<VecZnx, VecZnx, VecZnx, NEGATE>(&mut c_want, &a, &b);
|
||||||
|
|
||||||
assert_eq!(c_have.raw(), c_want.raw());
|
assert_eq!(c_have.raw(), c_want.raw());
|
||||||
});
|
});
|
||||||
|
|||||||
Reference in New Issue
Block a user