mirror of
https://github.com/arnaucube/poulpy.git
synced 2026-02-10 13:16:44 +01:00
Added VecZnxBig<FFT64> ops
This commit is contained in:
@@ -1,6 +1,6 @@
|
||||
use crate::{Backend, Module, assert_alignement, cast_mut};
|
||||
use itertools::izip;
|
||||
use std::cmp::{max, min};
|
||||
use std::cmp::min;
|
||||
|
||||
pub trait ZnxInfos {
|
||||
/// Returns the ring degree of the polynomials.
|
||||
@@ -243,105 +243,3 @@ where
|
||||
.for_each(|(x_in, x_out)| *x_out = *x_in);
|
||||
});
|
||||
}
|
||||
|
||||
pub fn znx_post_process_ternary_op<T: ZnxLayout + ZnxBasics, const NEGATE: bool>(c: &mut T, a: &T, b: &T)
|
||||
where
|
||||
<T as ZnxLayout>::Scalar: IntegerType,
|
||||
{
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
assert_ne!(a.as_ptr(), b.as_ptr());
|
||||
assert_ne!(b.as_ptr(), c.as_ptr());
|
||||
assert_ne!(a.as_ptr(), c.as_ptr());
|
||||
}
|
||||
|
||||
let a_cols: usize = a.cols();
|
||||
let b_cols: usize = b.cols();
|
||||
let c_cols: usize = c.cols();
|
||||
|
||||
let min_ab_cols: usize = min(a_cols, b_cols);
|
||||
let max_ab_cols: usize = max(a_cols, b_cols);
|
||||
|
||||
// Copies shared shared cols between (c, max(a, b))
|
||||
if a_cols != b_cols {
|
||||
let mut x: &T = a;
|
||||
if a_cols < b_cols {
|
||||
x = b;
|
||||
}
|
||||
|
||||
let min_size = min(c.size(), x.size());
|
||||
(min_ab_cols..min(max_ab_cols, c_cols)).for_each(|i| {
|
||||
(0..min_size).for_each(|j| {
|
||||
c.at_poly_mut(i, j).copy_from_slice(x.at_poly(i, j));
|
||||
if NEGATE {
|
||||
c.at_poly_mut(i, j).iter_mut().for_each(|x| *x = -*x);
|
||||
}
|
||||
});
|
||||
(min_size..c.size()).for_each(|j| {
|
||||
c.zero_at(i, j);
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
// Zeroes the cols of c > max(a, b).
|
||||
if c_cols > max_ab_cols {
|
||||
(max_ab_cols..c_cols).for_each(|i| {
|
||||
(0..c.size()).for_each(|j| {
|
||||
c.zero_at(i, j);
|
||||
})
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
#[inline(always)]
|
||||
pub fn apply_binary_op<B: Backend, T: ZnxBasics + ZnxLayout, const NEGATE: bool>(
|
||||
module: &Module<B>,
|
||||
c: &mut T,
|
||||
a: &T,
|
||||
b: &T,
|
||||
op: impl Fn(&mut [T::Scalar], &[T::Scalar], &[T::Scalar]),
|
||||
) where
|
||||
<T as ZnxLayout>::Scalar: IntegerType,
|
||||
{
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
assert_eq!(a.n(), module.n());
|
||||
assert_eq!(b.n(), module.n());
|
||||
assert_eq!(c.n(), module.n());
|
||||
assert_ne!(a.as_ptr(), b.as_ptr());
|
||||
}
|
||||
let a_cols: usize = a.cols();
|
||||
let b_cols: usize = b.cols();
|
||||
let c_cols: usize = c.cols();
|
||||
let min_ab_cols: usize = min(a_cols, b_cols);
|
||||
let min_cols: usize = min(c_cols, min_ab_cols);
|
||||
// Applies over shared cols between (a, b, c)
|
||||
(0..min_cols).for_each(|i| op(c.at_poly_mut(i, 0), a.at_poly(i, 0), b.at_poly(i, 0)));
|
||||
// Copies/Negates/Zeroes the remaining cols if op is not inplace.
|
||||
if c.as_ptr() != a.as_ptr() && c.as_ptr() != b.as_ptr() {
|
||||
znx_post_process_ternary_op::<T, NEGATE>(c, a, b);
|
||||
}
|
||||
}
|
||||
|
||||
#[inline(always)]
|
||||
pub fn apply_unary_op<B: Backend, T: ZnxBasics + ZnxLayout>(
|
||||
module: &Module<B>,
|
||||
b: &mut T,
|
||||
a: &T,
|
||||
op: impl Fn(&mut [T::Scalar], &[T::Scalar]),
|
||||
) where
|
||||
<T as ZnxLayout>::Scalar: IntegerType,
|
||||
{
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
assert_eq!(a.n(), module.n());
|
||||
assert_eq!(b.n(), module.n());
|
||||
}
|
||||
let a_cols: usize = a.cols();
|
||||
let b_cols: usize = b.cols();
|
||||
let min_cols: usize = min(a_cols, b_cols);
|
||||
// Applies over the shared cols between (a, b)
|
||||
(0..min_cols).for_each(|i| op(b.at_poly_mut(i, 0), a.at_poly(i, 0)));
|
||||
// Zeroes the remaining cols of b.
|
||||
(min_cols..b_cols).for_each(|i| (0..b.size()).for_each(|j| b.zero_at(i, j)));
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user