From 1f384ce54dd8223ff9f7c84dac80da696da34747 Mon Sep 17 00:00:00 2001 From: Jean-Philippe Bossuat Date: Thu, 8 May 2025 15:21:24 +0200 Subject: [PATCH] Added vec_znx_add/sub_scalar & available on Scratch --- base2k/src/lib.rs | 24 ++++++++---- base2k/src/vec_znx_ops.rs | 77 ++++++++++++++++++++++++++++++++++++++- 2 files changed, 91 insertions(+), 10 deletions(-) diff --git a/base2k/src/lib.rs b/base2k/src/lib.rs index 450a69f..bb8ce55 100644 --- a/base2k/src/lib.rs +++ b/base2k/src/lib.rs @@ -150,19 +150,27 @@ impl Scratch { unsafe { &mut *(data as *mut [u8] as *mut Self) } } - fn take_slice_aligned(data: &mut [u8], take_len: usize) -> (&mut [u8], &mut [u8]) { - let ptr = data.as_mut_ptr(); - let self_len = data.len(); + #[allow(dead_code)] + fn available(&self) -> usize { + let ptr: *const u8 = self.data.as_ptr(); + let self_len: usize = self.data.len(); + let aligned_offset: usize = ptr.align_offset(DEFAULTALIGN); + self_len.saturating_sub(aligned_offset) + } - let aligned_offset = ptr.align_offset(DEFAULTALIGN); - let aligned_len = self_len.saturating_sub(aligned_offset); + fn take_slice_aligned(data: &mut [u8], take_len: usize) -> (&mut [u8], &mut [u8]) { + let ptr: *mut u8 = data.as_mut_ptr(); + let self_len: usize = data.len(); + + let aligned_offset: usize = ptr.align_offset(DEFAULTALIGN); + let aligned_len: usize = self_len.saturating_sub(aligned_offset); if let Some(rem_len) = aligned_len.checked_sub(take_len) { unsafe { - let rem_ptr = ptr.add(aligned_offset).add(take_len); - let rem_slice = &mut *std::ptr::slice_from_raw_parts_mut(rem_ptr, rem_len); + let rem_ptr: *mut u8 = ptr.add(aligned_offset).add(take_len); + let rem_slice: &mut [u8] = &mut *std::ptr::slice_from_raw_parts_mut(rem_ptr, rem_len); - let take_slice = &mut *std::ptr::slice_from_raw_parts_mut(ptr.add(aligned_offset), take_len); + let take_slice: &mut [u8] = &mut *std::ptr::slice_from_raw_parts_mut(ptr.add(aligned_offset), take_len); return (take_slice, rem_slice); } diff --git a/base2k/src/vec_znx_ops.rs b/base2k/src/vec_znx_ops.rs index c80e9f1..f57e99f 100644 --- a/base2k/src/vec_znx_ops.rs +++ b/base2k/src/vec_znx_ops.rs @@ -1,6 +1,7 @@ use crate::ffi::vec_znx; use crate::{ - Backend, Module, Scratch, VecZnx, VecZnxOwned, VecZnxToMut, VecZnxToRef, ZnxInfos, ZnxSliceSize, ZnxView, ZnxViewMut, ZnxZero, + Backend, Module, ScalarZnxToRef, Scratch, VecZnx, VecZnxOwned, VecZnxToMut, VecZnxToRef, ZnxInfos, ZnxSliceSize, ZnxView, + ZnxViewMut, ZnxZero, }; use itertools::izip; use std::cmp::min; @@ -51,12 +52,18 @@ pub trait VecZnxOps { A: VecZnxToRef, B: VecZnxToRef; - /// Adds the selected column of `a` to the selected column of `b` and writes the result on the selected column of `res`. + /// Adds the selected column of `a` to the selected column of `res` and writes the result on the selected column of `res`. fn vec_znx_add_inplace(&self, res: &mut R, res_col: usize, a: &A, a_col: usize) where R: VecZnxToMut, A: VecZnxToRef; + /// Adds the selected column of `a` on the selected column and limb of `res`. + fn vec_znx_add_scalar_inplace(&self, res: &mut R, res_col: usize, res_limb: usize, a: &A, b_col: usize) + where + R: VecZnxToMut, + A: ScalarZnxToRef; + /// Subtracts the selected column of `b` from the selected column of `a` and writes the result on the selected column of `res`. fn vec_znx_sub(&self, res: &mut R, res_col: usize, a: &A, a_col: usize, b: &B, b_col: usize) where @@ -80,6 +87,12 @@ pub trait VecZnxOps { R: VecZnxToMut, A: VecZnxToRef; + /// Subtracts the selected column of `a` on the selected column and limb of `res`. + fn vec_znx_sub_scalar_inplace(&self, res: &mut R, res_col: usize, res_limb: usize, a: &A, b_col: usize) + where + R: VecZnxToMut, + A: ScalarZnxToRef; + // Negates the selected column of `a` and stores the result in `res_col` of `res`. fn vec_znx_negate(&self, res: &mut R, res_col: usize, a: &A, a_col: usize) where @@ -282,6 +295,36 @@ impl VecZnxOps for Module { } } + fn vec_znx_add_scalar_inplace(&self, res: &mut R, res_col: usize, res_limb: usize, a: &A, a_col: usize) + where + R: VecZnxToMut, + A: ScalarZnxToRef, + { + let mut res: VecZnx<&mut [u8]> = res.to_mut(); + let a: crate::ScalarZnx<&[u8]> = a.to_ref(); + + #[cfg(debug_assertions)] + { + assert_eq!(a.n(), self.n()); + assert_eq!(res.n(), self.n()); + } + + unsafe { + vec_znx::vec_znx_add( + self.ptr, + res.at_mut_ptr(res_col, res_limb), + 1 as u64, + res.sl() as u64, + a.at_ptr(a_col, 0), + a.size() as u64, + a.sl() as u64, + res.at_ptr(res_col, res_limb), + 1 as u64, + res.sl() as u64, + ) + } + } + fn vec_znx_sub(&self, res: &mut R, res_col: usize, a: &A, a_col: usize, b: &B, b_col: usize) where R: VecZnxToMut, @@ -315,6 +358,36 @@ impl VecZnxOps for Module { } } + fn vec_znx_sub_scalar_inplace(&self, res: &mut R, res_col: usize, res_limb: usize, a: &A, a_col: usize) + where + R: VecZnxToMut, + A: ScalarZnxToRef, + { + let mut res: VecZnx<&mut [u8]> = res.to_mut(); + let a: crate::ScalarZnx<&[u8]> = a.to_ref(); + + #[cfg(debug_assertions)] + { + assert_eq!(a.n(), self.n()); + assert_eq!(res.n(), self.n()); + } + + unsafe { + vec_znx::vec_znx_sub( + self.ptr, + res.at_mut_ptr(res_col, res_limb), + 1 as u64, + res.sl() as u64, + a.at_ptr(a_col, 0), + a.size() as u64, + a.sl() as u64, + res.at_ptr(res_col, res_limb), + 1 as u64, + res.sl() as u64, + ) + } + } + fn vec_znx_sub_ab_inplace(&self, res: &mut R, res_col: usize, a: &A, a_col: usize) where R: VecZnxToMut,