centralized sensitive code into VecZnxLayout

This commit is contained in:
Jean-Philippe Bossuat
2025-04-26 12:34:42 +02:00
parent 5841845e22
commit 6532f30f66
12 changed files with 218 additions and 322 deletions

View File

@@ -2,7 +2,7 @@ use crate::Backend;
use crate::cast_mut;
use crate::ffi::vec_znx;
use crate::ffi::znx;
use crate::{Infos, Module};
use crate::{Infos, Module, VecZnxLayout};
use crate::{alloc_aligned, assert_alignement};
use itertools::izip;
use std::cmp::min;
@@ -35,157 +35,6 @@ pub struct VecZnx {
pub ptr: *mut i64,
}
pub fn bytes_of_vec_znx(n: usize, cols: usize, limbs: usize) -> usize {
n * cols * limbs * size_of::<i64>()
}
impl VecZnx {
/// Returns a new struct implementing [VecZnx] with the provided data as backing array.
///
/// The struct will take ownership of buf[..[Self::bytes_of]]
///
/// User must ensure that data is properly alligned and that
/// the limbs of data is equal to [Self::bytes_of].
pub fn from_bytes(n: usize, cols: usize, limbs: usize, bytes: &mut [u8]) -> Self {
#[cfg(debug_assertions)]
{
assert!(cols > 0);
assert!(limbs > 0);
assert_eq!(bytes.len(), Self::bytes_of(n, cols, limbs));
assert_alignement(bytes.as_ptr());
}
unsafe {
let bytes_i64: &mut [i64] = cast_mut::<u8, i64>(bytes);
let ptr: *mut i64 = bytes_i64.as_mut_ptr();
Self {
n: n,
cols: cols,
limbs: limbs,
data: Vec::from_raw_parts(ptr, bytes.len(), bytes.len()),
ptr: ptr,
}
}
}
pub fn from_bytes_borrow(n: usize, cols: usize, limbs: usize, bytes: &mut [u8]) -> Self {
#[cfg(debug_assertions)]
{
assert!(cols > 0);
assert!(limbs > 0);
assert!(bytes.len() >= Self::bytes_of(n, cols, limbs));
assert_alignement(bytes.as_ptr());
}
Self {
n: n,
cols: cols,
limbs: limbs,
data: Vec::new(),
ptr: bytes.as_mut_ptr() as *mut i64,
}
}
pub fn bytes_of(n: usize, cols: usize, limbs: usize) -> usize {
bytes_of_vec_znx(n, cols, limbs)
}
pub fn copy_from(&mut self, a: &Self) {
copy_vec_znx_from(self, a);
}
pub fn borrowing(&self) -> bool {
self.data.len() == 0
}
/// Total limbs is [Self::n()] * [Self::poly_count()].
pub fn raw(&self) -> &[i64] {
unsafe { std::slice::from_raw_parts(self.ptr, self.n * self.poly_count()) }
}
/// Returns a reference to backend slice of the receiver.
/// Total size is [Self::n()] * [Self::poly_count()].
pub fn raw_mut(&mut self) -> &mut [i64] {
unsafe { std::slice::from_raw_parts_mut(self.ptr, self.n * self.poly_count()) }
}
/// Returns a non-mutable pointer to the backedn slice of the receiver.
pub fn as_ptr(&self) -> *const i64 {
self.ptr
}
/// Returns a mutable pointer to the backedn slice of the receiver.
pub fn as_mut_ptr(&mut self) -> *mut i64 {
self.ptr
}
/// Returns a non-mutable pointer starting a the (i, j)-th small poly.
pub fn at_ptr(&self, i: usize, j: usize) -> *const i64 {
#[cfg(debug_assertions)]
{
assert!(i < self.cols());
assert!(j < self.limbs());
}
let offset: usize = self.n * (j * self.cols() + i);
self.ptr.wrapping_add(offset)
}
/// Returns a non-mutable reference to the i-th limb.
/// The returned array is of size [Self::n()] * [Self::cols()].
pub fn at_limb(&self, i: usize) -> &[i64] {
unsafe { std::slice::from_raw_parts(self.at_ptr(0, i), self.n * self.cols()) }
}
/// Returns a non-mutable reference to the (i, j)-th poly.
/// The returned array is of size [Self::n()].
pub fn at_poly(&self, i: usize, j: usize) -> &[i64] {
unsafe { std::slice::from_raw_parts(self.at_ptr(i, j), self.n) }
}
/// Returns a mutable pointer starting a the (i, j)-th small poly.
pub fn at_mut_ptr(&mut self, i: usize, j: usize) -> *mut i64 {
#[cfg(debug_assertions)]
{
assert!(i < self.cols());
assert!(j < self.limbs());
}
let offset: usize = self.n * (j * self.cols() + i);
self.ptr.wrapping_add(offset)
}
/// Returns a mutable reference to the i-th limb.
/// The returned array is of size [Self::n()] * [Self::cols()].
pub fn at_limb_mut(&mut self, i: usize) -> &mut [i64] {
unsafe { std::slice::from_raw_parts_mut(self.at_mut_ptr(0, i), self.n * self.cols()) }
}
/// Returns a mutable reference to the (i, j)-th poly.
/// The returned array is of size [Self::n()].
pub fn at_poly_mut(&mut self, i: usize, j: usize) -> &mut [i64] {
unsafe { std::slice::from_raw_parts_mut(self.at_mut_ptr(i, j), self.n) }
}
pub fn zero(&mut self) {
unsafe { znx::znx_zero_i64_ref((self.n * self.poly_count()) as u64, self.ptr) }
}
pub fn normalize(&mut self, log_base2k: usize, carry: &mut [u8]) {
normalize(log_base2k, self, carry)
}
pub fn rsh(&mut self, log_base2k: usize, k: usize, carry: &mut [u8]) {
rsh(log_base2k, self, k, carry)
}
pub fn switch_degree(&self, a: &mut Self) {
switch_degree(a, self)
}
// Prints the first `n` coefficients of each limb
pub fn print(&self, n: usize) {
(0..self.limbs()).for_each(|i| println!("{}: {:?}", i, &self.at_limb(i)[..n]))
}
}
impl Infos for VecZnx {
fn n(&self) -> usize {
self.n
@@ -212,6 +61,18 @@ impl Infos for VecZnx {
}
}
impl VecZnxLayout for VecZnx {
type Scalar = i64;
fn as_ptr(&self) -> *const Self::Scalar {
self.ptr
}
fn as_mut_ptr(&mut self) -> *mut Self::Scalar {
self.ptr
}
}
/// Copies the coefficients of `a` on the receiver.
/// Copy is done with the minimum size matching both backing arrays.
/// Panics if the cols do not match.
@@ -271,6 +132,83 @@ impl VecZnx {
.for_each(|x: &mut i64| *x &= mask)
}
}
fn bytes_of(n: usize, cols: usize, limbs: usize) -> usize {
n * cols * limbs * size_of::<i64>()
}
/// Returns a new struct implementing [VecZnx] with the provided data as backing array.
///
/// The struct will take ownership of buf[..[Self::bytes_of]]
///
/// User must ensure that data is properly alligned and that
/// the limbs of data is equal to [Self::bytes_of].
pub fn from_bytes(n: usize, cols: usize, limbs: usize, bytes: &mut [u8]) -> Self {
#[cfg(debug_assertions)]
{
assert!(cols > 0);
assert!(limbs > 0);
assert_eq!(bytes.len(), Self::bytes_of(n, cols, limbs));
assert_alignement(bytes.as_ptr());
}
unsafe {
let bytes_i64: &mut [i64] = cast_mut::<u8, i64>(bytes);
let ptr: *mut i64 = bytes_i64.as_mut_ptr();
Self {
n: n,
cols: cols,
limbs: limbs,
data: Vec::from_raw_parts(ptr, bytes.len(), bytes.len()),
ptr: ptr,
}
}
}
pub fn from_bytes_borrow(n: usize, cols: usize, limbs: usize, bytes: &mut [u8]) -> Self {
#[cfg(debug_assertions)]
{
assert!(cols > 0);
assert!(limbs > 0);
assert!(bytes.len() >= Self::bytes_of(n, cols, limbs));
assert_alignement(bytes.as_ptr());
}
Self {
n: n,
cols: cols,
limbs: limbs,
data: Vec::new(),
ptr: bytes.as_mut_ptr() as *mut i64,
}
}
pub fn copy_from(&mut self, a: &Self) {
copy_vec_znx_from(self, a);
}
pub fn borrowing(&self) -> bool {
self.data.len() == 0
}
pub fn zero(&mut self) {
unsafe { znx::znx_zero_i64_ref((self.n * self.poly_count()) as u64, self.ptr) }
}
pub fn normalize(&mut self, log_base2k: usize, carry: &mut [u8]) {
normalize(log_base2k, self, carry)
}
pub fn rsh(&mut self, log_base2k: usize, k: usize, carry: &mut [u8]) {
rsh(log_base2k, self, k, carry)
}
pub fn switch_degree(&self, a: &mut Self) {
switch_degree(a, self)
}
// Prints the first `n` coefficients of each limb
pub fn print(&self, n: usize) {
(0..self.limbs()).for_each(|i| println!("{}: {:?}", i, &self.at_limb(i)[..n]))
}
}
pub fn switch_degree(b: &mut VecZnx, a: &VecZnx) {
@@ -395,6 +333,9 @@ pub trait VecZnxOps {
/// * `limbs`: the number of limbs per polynomial (a.k.a small polynomials).
fn new_vec_znx(&self, cols: usize, limbs: usize) -> VecZnx;
fn new_vec_znx_from_bytes(&self, cols: usize, limbs: usize, bytes: &mut [u8]) -> VecZnx;
fn new_vec_znx_from_bytes_borrow(&self, cols: usize, limbs: usize, tmp_bytes: &mut [u8]) -> VecZnx;
/// Returns the minimum number of bytes necessary to allocate
/// a new [VecZnx] through [VecZnx::from_bytes].
fn bytes_of_vec_znx(&self, cols: usize, size: usize) -> usize;
@@ -457,7 +398,15 @@ impl<B: Backend> VecZnxOps for Module<B> {
}
fn bytes_of_vec_znx(&self, cols: usize, limbs: usize) -> usize {
bytes_of_vec_znx(self.n(), cols, limbs)
VecZnx::bytes_of(self.n(), cols, limbs)
}
fn new_vec_znx_from_bytes(&self, cols: usize, limbs: usize, bytes: &mut [u8]) -> VecZnx {
VecZnx::from_bytes(self.n(), cols, limbs, bytes)
}
fn new_vec_znx_from_bytes_borrow(&self, cols: usize, limbs: usize, tmp_bytes: &mut [u8]) -> VecZnx {
VecZnx::from_bytes_borrow(self.n(), cols, limbs, tmp_bytes)
}
fn vec_znx_normalize_tmp_bytes(&self, cols: usize) -> usize {