From 3ee69866bd5936bb6c5ef390e5bbcac4bf45d562 Mon Sep 17 00:00:00 2001 From: Jean-Philippe Bossuat Date: Tue, 29 Apr 2025 14:33:07 +0200 Subject: [PATCH] Generalized apply_binary_op & apply_unary_op --- base2k/src/commons.rs | 55 +++++++++++++++++++++++++++++++++++++- base2k/src/vec_znx_ops.rs | 56 +++++---------------------------------- 2 files changed, 60 insertions(+), 51 deletions(-) diff --git a/base2k/src/commons.rs b/base2k/src/commons.rs index 1d7a0c9..cfae556 100644 --- a/base2k/src/commons.rs +++ b/base2k/src/commons.rs @@ -244,7 +244,7 @@ where }); } -pub fn znx_post_process_ternary_op(c: &mut T, a: &T, b: &T) +pub fn znx_post_process_ternary_op(c: &mut T, a: &T, b: &T) where ::Scalar: IntegerType, { @@ -292,3 +292,56 @@ where }); } } + +#[inline(always)] +pub fn apply_binary_op( + module: &Module, + c: &mut T, + a: &T, + b: &T, + op: impl Fn(&mut [T::Scalar], &[T::Scalar], &[T::Scalar]), +) where + ::Scalar: IntegerType, +{ + #[cfg(debug_assertions)] + { + assert_eq!(a.n(), module.n()); + assert_eq!(b.n(), module.n()); + assert_eq!(c.n(), module.n()); + assert_ne!(a.as_ptr(), b.as_ptr()); + } + let a_cols: usize = a.cols(); + let b_cols: usize = b.cols(); + let c_cols: usize = c.cols(); + let min_ab_cols: usize = min(a_cols, b_cols); + let min_cols: usize = min(c_cols, min_ab_cols); + // Applies over shared cols between (a, b, c) + (0..min_cols).for_each(|i| op(c.at_poly_mut(i, 0), a.at_poly(i, 0), b.at_poly(i, 0))); + // Copies/Negates/Zeroes the remaining cols if op is not inplace. + if c.as_ptr() != a.as_ptr() && c.as_ptr() != b.as_ptr() { + znx_post_process_ternary_op::(c, a, b); + } +} + +#[inline(always)] +pub fn apply_unary_op( + module: &Module, + b: &mut T, + a: &T, + op: impl Fn(&mut [T::Scalar], &[T::Scalar]), +) where + ::Scalar: IntegerType, +{ + #[cfg(debug_assertions)] + { + assert_eq!(a.n(), module.n()); + assert_eq!(b.n(), module.n()); + } + let a_cols: usize = a.cols(); + let b_cols: usize = b.cols(); + let min_cols: usize = min(a_cols, b_cols); + // Applies over the shared cols between (a, b) + (0..min_cols).for_each(|i| op(b.at_poly_mut(i, 0), a.at_poly(i, 0))); + // Zeroes the remaining cols of b. + (min_cols..b_cols).for_each(|i| (0..b.size()).for_each(|j| b.zero_at(i, j))); +} diff --git a/base2k/src/vec_znx_ops.rs b/base2k/src/vec_znx_ops.rs index 4c8409d..573e5b1 100644 --- a/base2k/src/vec_znx_ops.rs +++ b/base2k/src/vec_znx_ops.rs @@ -1,6 +1,6 @@ use crate::ffi::module::MODULE; use crate::ffi::vec_znx; -use crate::{Backend, Module, VecZnx, ZnxBase, ZnxBasics, ZnxInfos, ZnxLayout, switch_degree, znx_post_process_ternary_op}; +use crate::{apply_binary_op, apply_unary_op, switch_degree, znx_post_process_ternary_op, Backend, Module, VecZnx, ZnxBase, ZnxBasics, ZnxInfos, ZnxLayout}; use std::cmp::min; pub trait VecZnxOps { /// Allocates a new [VecZnx]. @@ -125,7 +125,7 @@ impl VecZnxOps for Module { b.sl(), vec_znx::vec_znx_add, ); - vec_znx_apply_binary_op::(self, c, a, b, op); + apply_binary_op::(self, c, a, b, op); } fn vec_znx_add_inplace(&self, b: &mut VecZnx, a: &VecZnx) { @@ -146,7 +146,7 @@ impl VecZnxOps for Module { b.sl(), vec_znx::vec_znx_sub, ); - vec_znx_apply_binary_op::(self, c, a, b, op); + apply_binary_op::(self, c, a, b, op); } fn vec_znx_sub_ab_inplace(&self, b: &mut VecZnx, a: &VecZnx) { @@ -172,7 +172,7 @@ impl VecZnxOps for Module { a.sl(), vec_znx::vec_znx_negate, ); - vec_znx_apply_unary_op::(self, b, a, op); + apply_unary_op::(self, b, a, op); } fn vec_znx_negate_inplace(&self, a: &mut VecZnx) { @@ -192,7 +192,7 @@ impl VecZnxOps for Module { a.sl(), vec_znx::vec_znx_rotate, ); - vec_znx_apply_unary_op::(self, b, a, op); + apply_unary_op::(self, b, a, op); } fn vec_znx_rotate_inplace(&self, k: i64, a: &mut VecZnx) { @@ -212,7 +212,7 @@ impl VecZnxOps for Module { a.sl(), vec_znx::vec_znx_automorphism, ); - vec_znx_apply_unary_op::(self, b, a, op); + apply_unary_op::(self, b, a, op); } fn vec_znx_automorphism_inplace(&self, k: i64, a: &mut VecZnx) { @@ -342,50 +342,6 @@ fn ffi_binary_op_factory_type_1( } } -#[inline(always)] -pub fn vec_znx_apply_binary_op( - module: &Module, - c: &mut VecZnx, - a: &VecZnx, - b: &VecZnx, - op: impl Fn(&mut [i64], &[i64], &[i64]), -) { - #[cfg(debug_assertions)] - { - assert_eq!(a.n(), module.n()); - assert_eq!(b.n(), module.n()); - assert_eq!(c.n(), module.n()); - assert_ne!(a.as_ptr(), b.as_ptr()); - } - let a_cols: usize = a.cols(); - let b_cols: usize = b.cols(); - let c_cols: usize = c.cols(); - let min_ab_cols: usize = min(a_cols, b_cols); - let min_cols: usize = min(c_cols, min_ab_cols); - // Applies over shared cols between (a, b, c) - (0..min_cols).for_each(|i| op(c.at_poly_mut(i, 0), a.at_poly(i, 0), b.at_poly(i, 0))); - // Copies/Negates/Zeroes the remaining cols if op is not inplace. - if c.as_ptr() != a.as_ptr() && c.as_ptr() != b.as_ptr() { - znx_post_process_ternary_op::(c, a, b); - } -} - -#[inline(always)] -pub fn vec_znx_apply_unary_op(module: &Module, b: &mut VecZnx, a: &VecZnx, op: impl Fn(&mut [i64], &[i64])) { - #[cfg(debug_assertions)] - { - assert_eq!(a.n(), module.n()); - assert_eq!(b.n(), module.n()); - } - let a_cols: usize = a.cols(); - let b_cols: usize = b.cols(); - let min_cols: usize = min(a_cols, b_cols); - // Applies over the shared cols between (a, b) - (0..min_cols).for_each(|i| op(b.at_poly_mut(i, 0), a.at_poly(i, 0))); - // Zeroes the remaining cols of b. - (min_cols..b_cols).for_each(|i| (0..b.size()).for_each(|j| b.zero_at(i, j))); -} - #[cfg(test)] mod tests { use crate::{