Added vec_znx_add/sub_scalar & available on Scratch

This commit is contained in:
Jean-Philippe Bossuat
2025-05-08 15:21:24 +02:00
parent 8b3b2e4b9c
commit 1f384ce54d
2 changed files with 91 additions and 10 deletions

View File

@@ -150,19 +150,27 @@ impl Scratch {
unsafe { &mut *(data as *mut [u8] as *mut Self) }
}
fn take_slice_aligned(data: &mut [u8], take_len: usize) -> (&mut [u8], &mut [u8]) {
let ptr = data.as_mut_ptr();
let self_len = data.len();
#[allow(dead_code)]
fn available(&self) -> usize {
let ptr: *const u8 = self.data.as_ptr();
let self_len: usize = self.data.len();
let aligned_offset: usize = ptr.align_offset(DEFAULTALIGN);
self_len.saturating_sub(aligned_offset)
}
let aligned_offset = ptr.align_offset(DEFAULTALIGN);
let aligned_len = self_len.saturating_sub(aligned_offset);
fn take_slice_aligned(data: &mut [u8], take_len: usize) -> (&mut [u8], &mut [u8]) {
let ptr: *mut u8 = data.as_mut_ptr();
let self_len: usize = data.len();
let aligned_offset: usize = ptr.align_offset(DEFAULTALIGN);
let aligned_len: usize = self_len.saturating_sub(aligned_offset);
if let Some(rem_len) = aligned_len.checked_sub(take_len) {
unsafe {
let rem_ptr = ptr.add(aligned_offset).add(take_len);
let rem_slice = &mut *std::ptr::slice_from_raw_parts_mut(rem_ptr, rem_len);
let rem_ptr: *mut u8 = ptr.add(aligned_offset).add(take_len);
let rem_slice: &mut [u8] = &mut *std::ptr::slice_from_raw_parts_mut(rem_ptr, rem_len);
let take_slice = &mut *std::ptr::slice_from_raw_parts_mut(ptr.add(aligned_offset), take_len);
let take_slice: &mut [u8] = &mut *std::ptr::slice_from_raw_parts_mut(ptr.add(aligned_offset), take_len);
return (take_slice, rem_slice);
}

View File

@@ -1,6 +1,7 @@
use crate::ffi::vec_znx;
use crate::{
Backend, Module, Scratch, VecZnx, VecZnxOwned, VecZnxToMut, VecZnxToRef, ZnxInfos, ZnxSliceSize, ZnxView, ZnxViewMut, ZnxZero,
Backend, Module, ScalarZnxToRef, Scratch, VecZnx, VecZnxOwned, VecZnxToMut, VecZnxToRef, ZnxInfos, ZnxSliceSize, ZnxView,
ZnxViewMut, ZnxZero,
};
use itertools::izip;
use std::cmp::min;
@@ -51,12 +52,18 @@ pub trait VecZnxOps {
A: VecZnxToRef,
B: VecZnxToRef;
/// Adds the selected column of `a` to the selected column of `b` and writes the result on the selected column of `res`.
/// Adds the selected column of `a` to the selected column of `res` and writes the result on the selected column of `res`.
fn vec_znx_add_inplace<R, A>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize)
where
R: VecZnxToMut,
A: VecZnxToRef;
/// Adds the selected column of `a` on the selected column and limb of `res`.
fn vec_znx_add_scalar_inplace<R, A>(&self, res: &mut R, res_col: usize, res_limb: usize, a: &A, b_col: usize)
where
R: VecZnxToMut,
A: ScalarZnxToRef;
/// Subtracts the selected column of `b` from the selected column of `a` and writes the result on the selected column of `res`.
fn vec_znx_sub<R, A, B>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize, b: &B, b_col: usize)
where
@@ -80,6 +87,12 @@ pub trait VecZnxOps {
R: VecZnxToMut,
A: VecZnxToRef;
/// Subtracts the selected column of `a` on the selected column and limb of `res`.
fn vec_znx_sub_scalar_inplace<R, A>(&self, res: &mut R, res_col: usize, res_limb: usize, a: &A, b_col: usize)
where
R: VecZnxToMut,
A: ScalarZnxToRef;
// Negates the selected column of `a` and stores the result in `res_col` of `res`.
fn vec_znx_negate<R, A>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize)
where
@@ -282,6 +295,36 @@ impl<BACKEND: Backend> VecZnxOps for Module<BACKEND> {
}
}
fn vec_znx_add_scalar_inplace<R, A>(&self, res: &mut R, res_col: usize, res_limb: usize, a: &A, a_col: usize)
where
R: VecZnxToMut,
A: ScalarZnxToRef,
{
let mut res: VecZnx<&mut [u8]> = res.to_mut();
let a: crate::ScalarZnx<&[u8]> = a.to_ref();
#[cfg(debug_assertions)]
{
assert_eq!(a.n(), self.n());
assert_eq!(res.n(), self.n());
}
unsafe {
vec_znx::vec_znx_add(
self.ptr,
res.at_mut_ptr(res_col, res_limb),
1 as u64,
res.sl() as u64,
a.at_ptr(a_col, 0),
a.size() as u64,
a.sl() as u64,
res.at_ptr(res_col, res_limb),
1 as u64,
res.sl() as u64,
)
}
}
fn vec_znx_sub<R, A, B>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize, b: &B, b_col: usize)
where
R: VecZnxToMut,
@@ -315,6 +358,36 @@ impl<BACKEND: Backend> VecZnxOps for Module<BACKEND> {
}
}
fn vec_znx_sub_scalar_inplace<R, A>(&self, res: &mut R, res_col: usize, res_limb: usize, a: &A, a_col: usize)
where
R: VecZnxToMut,
A: ScalarZnxToRef,
{
let mut res: VecZnx<&mut [u8]> = res.to_mut();
let a: crate::ScalarZnx<&[u8]> = a.to_ref();
#[cfg(debug_assertions)]
{
assert_eq!(a.n(), self.n());
assert_eq!(res.n(), self.n());
}
unsafe {
vec_znx::vec_znx_sub(
self.ptr,
res.at_mut_ptr(res_col, res_limb),
1 as u64,
res.sl() as u64,
a.at_ptr(a_col, 0),
a.size() as u64,
a.sl() as u64,
res.at_ptr(res_col, res_limb),
1 as u64,
res.sl() as u64,
)
}
}
fn vec_znx_sub_ab_inplace<R, A>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize)
where
R: VecZnxToMut,