more refactoring

This commit is contained in:
Jean-Philippe Bossuat
2025-04-26 13:19:22 +02:00
parent 6532f30f66
commit 54148acf6b
25 changed files with 294 additions and 256 deletions

View File

@@ -1,8 +1,9 @@
use crate::Backend;
use crate::ZnxBase;
use crate::cast_mut;
use crate::ffi::vec_znx;
use crate::ffi::znx;
use crate::{Infos, Module, VecZnxLayout};
use crate::{Module, ZnxInfos, ZnxLayout};
use crate::{alloc_aligned, assert_alignement};
use itertools::izip;
use std::cmp::min;
@@ -35,7 +36,7 @@ pub struct VecZnx {
pub ptr: *mut i64,
}
impl Infos for VecZnx {
impl ZnxInfos for VecZnx {
fn n(&self) -> usize {
self.n
}
@@ -61,7 +62,7 @@ impl Infos for VecZnx {
}
}
impl VecZnxLayout for VecZnx {
impl ZnxLayout for VecZnx {
type Scalar = i64;
fn as_ptr(&self) -> *const Self::Scalar {
@@ -84,9 +85,12 @@ pub fn copy_vec_znx_from(b: &mut VecZnx, a: &VecZnx) {
data_b[..size].copy_from_slice(&data_a[..size])
}
impl VecZnx {
impl<B: Backend> ZnxBase<B> for VecZnx {
type Scalar = i64;
/// Allocates a new [VecZnx] composed of #size polynomials of Z\[X\].
pub fn new(n: usize, cols: usize, limbs: usize) -> Self {
fn new(module: &Module<B>, cols: usize, limbs: usize) -> Self {
let n: usize = module.n();
#[cfg(debug_assertions)]
{
assert!(n > 0);
@@ -94,7 +98,7 @@ impl VecZnx {
assert!(cols > 0);
assert!(limbs > 0);
}
let mut data: Vec<i64> = alloc_aligned::<i64>(n * cols * limbs);
let mut data: Vec<i64> = alloc_aligned::<i64>(Self::bytes_of(module, cols, limbs));
let ptr: *mut i64 = data.as_mut_ptr();
Self {
n: n,
@@ -105,6 +109,57 @@ impl VecZnx {
}
}
fn bytes_of(module: &Module<B>, cols: usize, limbs: usize) -> usize {
module.n() * cols * limbs * size_of::<i64>()
}
/// Returns a new struct implementing [VecZnx] with the provided data as backing array.
///
/// The struct will take ownership of buf[..[Self::bytes_of]]
///
/// User must ensure that data is properly alligned and that
/// the limbs of data is equal to [Self::bytes_of].
fn from_bytes(module: &Module<B>, cols: usize, limbs: usize, bytes: &mut [u8]) -> Self {
let n: usize = module.n();
#[cfg(debug_assertions)]
{
assert!(cols > 0);
assert!(limbs > 0);
assert_eq!(bytes.len(), Self::bytes_of(module, cols, limbs));
assert_alignement(bytes.as_ptr());
}
unsafe {
let bytes_i64: &mut [i64] = cast_mut::<u8, i64>(bytes);
let ptr: *mut i64 = bytes_i64.as_mut_ptr();
Self {
n: n,
cols: cols,
limbs: limbs,
data: Vec::from_raw_parts(ptr, bytes.len(), bytes.len()),
ptr: ptr,
}
}
}
fn from_bytes_borrow(module: &Module<B>, cols: usize, limbs: usize, bytes: &mut [u8]) -> Self {
#[cfg(debug_assertions)]
{
assert!(cols > 0);
assert!(limbs > 0);
assert!(bytes.len() >= Self::bytes_of(module, cols, limbs));
assert_alignement(bytes.as_ptr());
}
Self {
n: module.n(),
cols: cols,
limbs: limbs,
data: Vec::new(),
ptr: bytes.as_mut_ptr() as *mut i64,
}
}
}
impl VecZnx {
/// Truncates the precision of the [VecZnx] by k bits.
///
/// # Arguments
@@ -133,54 +188,6 @@ impl VecZnx {
}
}
fn bytes_of(n: usize, cols: usize, limbs: usize) -> usize {
n * cols * limbs * size_of::<i64>()
}
/// Returns a new struct implementing [VecZnx] with the provided data as backing array.
///
/// The struct will take ownership of buf[..[Self::bytes_of]]
///
/// User must ensure that data is properly alligned and that
/// the limbs of data is equal to [Self::bytes_of].
pub fn from_bytes(n: usize, cols: usize, limbs: usize, bytes: &mut [u8]) -> Self {
#[cfg(debug_assertions)]
{
assert!(cols > 0);
assert!(limbs > 0);
assert_eq!(bytes.len(), Self::bytes_of(n, cols, limbs));
assert_alignement(bytes.as_ptr());
}
unsafe {
let bytes_i64: &mut [i64] = cast_mut::<u8, i64>(bytes);
let ptr: *mut i64 = bytes_i64.as_mut_ptr();
Self {
n: n,
cols: cols,
limbs: limbs,
data: Vec::from_raw_parts(ptr, bytes.len(), bytes.len()),
ptr: ptr,
}
}
}
pub fn from_bytes_borrow(n: usize, cols: usize, limbs: usize, bytes: &mut [u8]) -> Self {
#[cfg(debug_assertions)]
{
assert!(cols > 0);
assert!(limbs > 0);
assert!(bytes.len() >= Self::bytes_of(n, cols, limbs));
assert_alignement(bytes.as_ptr());
}
Self {
n: n,
cols: cols,
limbs: limbs,
data: Vec::new(),
ptr: bytes.as_mut_ptr() as *mut i64,
}
}
pub fn copy_from(&mut self, a: &Self) {
copy_vec_znx_from(self, a);
}
@@ -394,19 +401,19 @@ pub trait VecZnxOps {
impl<B: Backend> VecZnxOps for Module<B> {
fn new_vec_znx(&self, cols: usize, limbs: usize) -> VecZnx {
VecZnx::new(self.n(), cols, limbs)
VecZnx::new(self, cols, limbs)
}
fn bytes_of_vec_znx(&self, cols: usize, limbs: usize) -> usize {
VecZnx::bytes_of(self.n(), cols, limbs)
VecZnx::bytes_of(self, cols, limbs)
}
fn new_vec_znx_from_bytes(&self, cols: usize, limbs: usize, bytes: &mut [u8]) -> VecZnx {
VecZnx::from_bytes(self.n(), cols, limbs, bytes)
VecZnx::from_bytes(self, cols, limbs, bytes)
}
fn new_vec_znx_from_bytes_borrow(&self, cols: usize, limbs: usize, tmp_bytes: &mut [u8]) -> VecZnx {
VecZnx::from_bytes_borrow(self.n(), cols, limbs, tmp_bytes)
VecZnx::from_bytes_borrow(self, cols, limbs, tmp_bytes)
}
fn vec_znx_normalize_tmp_bytes(&self, cols: usize) -> usize {