Generalized apply_binary_op & apply_unary_op

This commit is contained in:
Jean-Philippe Bossuat
2025-04-29 14:33:07 +02:00
parent d86d6b6ee8
commit 3ee69866bd
2 changed files with 60 additions and 51 deletions

View File

@@ -244,7 +244,7 @@ where
});
}
pub fn znx_post_process_ternary_op<T: ZnxInfos + ZnxLayout + ZnxBasics, const NEGATE: bool>(c: &mut T, a: &T, b: &T)
pub fn znx_post_process_ternary_op<T: ZnxLayout + ZnxBasics, const NEGATE: bool>(c: &mut T, a: &T, b: &T)
where
<T as ZnxLayout>::Scalar: IntegerType,
{
@@ -292,3 +292,56 @@ where
});
}
}
#[inline(always)]
pub fn apply_binary_op<B: Backend, T: ZnxBasics + ZnxLayout, const NEGATE: bool>(
module: &Module<B>,
c: &mut T,
a: &T,
b: &T,
op: impl Fn(&mut [T::Scalar], &[T::Scalar], &[T::Scalar]),
) where
<T as ZnxLayout>::Scalar: IntegerType,
{
#[cfg(debug_assertions)]
{
assert_eq!(a.n(), module.n());
assert_eq!(b.n(), module.n());
assert_eq!(c.n(), module.n());
assert_ne!(a.as_ptr(), b.as_ptr());
}
let a_cols: usize = a.cols();
let b_cols: usize = b.cols();
let c_cols: usize = c.cols();
let min_ab_cols: usize = min(a_cols, b_cols);
let min_cols: usize = min(c_cols, min_ab_cols);
// Applies over shared cols between (a, b, c)
(0..min_cols).for_each(|i| op(c.at_poly_mut(i, 0), a.at_poly(i, 0), b.at_poly(i, 0)));
// Copies/Negates/Zeroes the remaining cols if op is not inplace.
if c.as_ptr() != a.as_ptr() && c.as_ptr() != b.as_ptr() {
znx_post_process_ternary_op::<T, NEGATE>(c, a, b);
}
}
#[inline(always)]
pub fn apply_unary_op<B: Backend, T: ZnxBasics + ZnxLayout>(
module: &Module<B>,
b: &mut T,
a: &T,
op: impl Fn(&mut [T::Scalar], &[T::Scalar]),
) where
<T as ZnxLayout>::Scalar: IntegerType,
{
#[cfg(debug_assertions)]
{
assert_eq!(a.n(), module.n());
assert_eq!(b.n(), module.n());
}
let a_cols: usize = a.cols();
let b_cols: usize = b.cols();
let min_cols: usize = min(a_cols, b_cols);
// Applies over the shared cols between (a, b)
(0..min_cols).for_each(|i| op(b.at_poly_mut(i, 0), a.at_poly(i, 0)));
// Zeroes the remaining cols of b.
(min_cols..b_cols).for_each(|i| (0..b.size()).for_each(|j| b.zero_at(i, j)));
}

View File

@@ -1,6 +1,6 @@
use crate::ffi::module::MODULE;
use crate::ffi::vec_znx;
use crate::{Backend, Module, VecZnx, ZnxBase, ZnxBasics, ZnxInfos, ZnxLayout, switch_degree, znx_post_process_ternary_op};
use crate::{apply_binary_op, apply_unary_op, switch_degree, znx_post_process_ternary_op, Backend, Module, VecZnx, ZnxBase, ZnxBasics, ZnxInfos, ZnxLayout};
use std::cmp::min;
pub trait VecZnxOps {
/// Allocates a new [VecZnx].
@@ -125,7 +125,7 @@ impl<B: Backend> VecZnxOps for Module<B> {
b.sl(),
vec_znx::vec_znx_add,
);
vec_znx_apply_binary_op::<B, false>(self, c, a, b, op);
apply_binary_op::<B, VecZnx, false>(self, c, a, b, op);
}
fn vec_znx_add_inplace(&self, b: &mut VecZnx, a: &VecZnx) {
@@ -146,7 +146,7 @@ impl<B: Backend> VecZnxOps for Module<B> {
b.sl(),
vec_znx::vec_znx_sub,
);
vec_znx_apply_binary_op::<B, true>(self, c, a, b, op);
apply_binary_op::<B, VecZnx, true>(self, c, a, b, op);
}
fn vec_znx_sub_ab_inplace(&self, b: &mut VecZnx, a: &VecZnx) {
@@ -172,7 +172,7 @@ impl<B: Backend> VecZnxOps for Module<B> {
a.sl(),
vec_znx::vec_znx_negate,
);
vec_znx_apply_unary_op::<B>(self, b, a, op);
apply_unary_op::<B, VecZnx>(self, b, a, op);
}
fn vec_znx_negate_inplace(&self, a: &mut VecZnx) {
@@ -192,7 +192,7 @@ impl<B: Backend> VecZnxOps for Module<B> {
a.sl(),
vec_znx::vec_znx_rotate,
);
vec_znx_apply_unary_op::<B>(self, b, a, op);
apply_unary_op::<B, VecZnx>(self, b, a, op);
}
fn vec_znx_rotate_inplace(&self, k: i64, a: &mut VecZnx) {
@@ -212,7 +212,7 @@ impl<B: Backend> VecZnxOps for Module<B> {
a.sl(),
vec_znx::vec_znx_automorphism,
);
vec_znx_apply_unary_op::<B>(self, b, a, op);
apply_unary_op::<B, VecZnx>(self, b, a, op);
}
fn vec_znx_automorphism_inplace(&self, k: i64, a: &mut VecZnx) {
@@ -342,50 +342,6 @@ fn ffi_binary_op_factory_type_1(
}
}
#[inline(always)]
pub fn vec_znx_apply_binary_op<B: Backend, const NEGATE: bool>(
module: &Module<B>,
c: &mut VecZnx,
a: &VecZnx,
b: &VecZnx,
op: impl Fn(&mut [i64], &[i64], &[i64]),
) {
#[cfg(debug_assertions)]
{
assert_eq!(a.n(), module.n());
assert_eq!(b.n(), module.n());
assert_eq!(c.n(), module.n());
assert_ne!(a.as_ptr(), b.as_ptr());
}
let a_cols: usize = a.cols();
let b_cols: usize = b.cols();
let c_cols: usize = c.cols();
let min_ab_cols: usize = min(a_cols, b_cols);
let min_cols: usize = min(c_cols, min_ab_cols);
// Applies over shared cols between (a, b, c)
(0..min_cols).for_each(|i| op(c.at_poly_mut(i, 0), a.at_poly(i, 0), b.at_poly(i, 0)));
// Copies/Negates/Zeroes the remaining cols if op is not inplace.
if c.as_ptr() != a.as_ptr() && c.as_ptr() != b.as_ptr() {
znx_post_process_ternary_op::<VecZnx, NEGATE>(c, a, b);
}
}
#[inline(always)]
pub fn vec_znx_apply_unary_op<B: Backend>(module: &Module<B>, b: &mut VecZnx, a: &VecZnx, op: impl Fn(&mut [i64], &[i64])) {
#[cfg(debug_assertions)]
{
assert_eq!(a.n(), module.n());
assert_eq!(b.n(), module.n());
}
let a_cols: usize = a.cols();
let b_cols: usize = b.cols();
let min_cols: usize = min(a_cols, b_cols);
// Applies over the shared cols between (a, b)
(0..min_cols).for_each(|i| op(b.at_poly_mut(i, 0), a.at_poly(i, 0)));
// Zeroes the remaining cols of b.
(min_cols..b_cols).for_each(|i| (0..b.size()).for_each(|j| b.zero_at(i, j)));
}
#[cfg(test)]
mod tests {
use crate::{