more refactoring

This commit is contained in:
Jean-Philippe Bossuat
2025-04-26 13:19:22 +02:00
parent 6532f30f66
commit 54148acf6b
25 changed files with 294 additions and 256 deletions

View File

@@ -1,5 +1,5 @@
use crate::ffi::vec_znx_big::{self, vec_znx_big_t};
use crate::{Backend, FFT64, Infos, Module, VecZnx, VecZnxDft, VecZnxLayout, alloc_aligned, assert_alignement};
use crate::{Backend, FFT64, Module, VecZnx, VecZnxDft, ZnxBase, ZnxInfos, ZnxLayout, alloc_aligned, assert_alignement};
use std::marker::PhantomData;
pub struct VecZnxBig<B: Backend> {
@@ -10,16 +10,17 @@ pub struct VecZnxBig<B: Backend> {
pub limbs: usize,
pub _marker: PhantomData<B>,
}
impl<B: Backend> ZnxBase<B> for VecZnxBig<B> {
type Scalar = u8;
impl VecZnxBig<FFT64> {
pub fn new(module: &Module<FFT64>, cols: usize, limbs: usize) -> Self {
fn new(module: &Module<B>, cols: usize, limbs: usize) -> Self {
#[cfg(debug_assertions)]
{
assert!(cols > 0);
assert!(limbs > 0);
}
let mut data: Vec<u8> = alloc_aligned::<u8>(module.bytes_of_vec_znx_big(cols, limbs));
let ptr: *mut u8 = data.as_mut_ptr();
let mut data: Vec<Self::Scalar> = alloc_aligned::<u8>(Self::bytes_of(module, cols, limbs));
let ptr: *mut Self::Scalar = data.as_mut_ptr();
Self {
data: data,
ptr: ptr,
@@ -30,15 +31,19 @@ impl VecZnxBig<FFT64> {
}
}
fn bytes_of(module: &Module<B>, cols: usize, limbs: usize) -> usize {
unsafe { vec_znx_big::bytes_of_vec_znx_big(module.ptr, limbs as u64) as usize * cols }
}
/// Returns a new [VecZnxBig] with the provided data as backing array.
/// User must ensure that data is properly alligned and that
/// the size of data is at least equal to [Module::bytes_of_vec_znx_big].
pub fn from_bytes(module: &Module<FFT64>, cols: usize, limbs: usize, bytes: &mut [u8]) -> Self {
fn from_bytes(module: &Module<B>, cols: usize, limbs: usize, bytes: &mut [Self::Scalar]) -> Self {
#[cfg(debug_assertions)]
{
assert!(cols > 0);
assert!(limbs > 0);
assert_eq!(bytes.len(), module.bytes_of_vec_znx_big(cols, limbs));
assert_eq!(bytes.len(), Self::bytes_of(module, cols, limbs));
assert_alignement(bytes.as_ptr())
};
unsafe {
@@ -53,12 +58,12 @@ impl VecZnxBig<FFT64> {
}
}
pub fn from_bytes_borrow(module: &Module<FFT64>, cols: usize, limbs: usize, bytes: &mut [u8]) -> Self {
fn from_bytes_borrow(module: &Module<B>, cols: usize, limbs: usize, bytes: &mut [Self::Scalar]) -> Self {
#[cfg(debug_assertions)]
{
assert!(cols > 0);
assert!(limbs > 0);
assert_eq!(bytes.len(), module.bytes_of_vec_znx_big(cols, limbs));
assert_eq!(bytes.len(), Self::bytes_of(module, cols, limbs));
assert_alignement(bytes.as_ptr());
}
Self {
@@ -70,24 +75,9 @@ impl VecZnxBig<FFT64> {
_marker: PhantomData,
}
}
pub fn as_vec_znx_dft(&mut self) -> VecZnxDft<FFT64> {
VecZnxDft::<FFT64> {
data: Vec::new(),
ptr: self.ptr,
n: self.n,
cols: self.cols,
limbs: self.limbs,
_marker: self._marker,
}
}
pub fn print(&self, n: usize) {
(0..self.limbs()).for_each(|i| println!("{}: {:?}", i, &self.at_limb(i)[..n]));
}
}
impl<B: Backend> Infos for VecZnxBig<B> {
impl<B: Backend> ZnxInfos for VecZnxBig<B> {
fn log_n(&self) -> usize {
(usize::BITS - (self.n - 1).leading_zeros()) as _
}
@@ -113,7 +103,7 @@ impl<B: Backend> Infos for VecZnxBig<B> {
}
}
impl VecZnxLayout for VecZnxBig<FFT64> {
impl ZnxLayout for VecZnxBig<FFT64> {
type Scalar = i64;
fn as_ptr(&self) -> *const Self::Scalar {
@@ -125,6 +115,12 @@ impl VecZnxLayout for VecZnxBig<FFT64> {
}
}
impl VecZnxBig<FFT64> {
pub fn print(&self, n: usize) {
(0..self.limbs()).for_each(|i| println!("{}: {:?}", i, &self.at_limb(i)[..n]));
}
}
pub trait VecZnxBigOps<B: Backend> {
/// Allocates a vector Z[X]/(X^N+1) that stores not normalized values.
fn new_vec_znx_big(&self, cols: usize, limbs: usize) -> VecZnxBig<B>;
@@ -220,7 +216,7 @@ impl VecZnxBigOps<FFT64> for Module<FFT64> {
}
fn bytes_of_vec_znx_big(&self, cols: usize, limbs: usize) -> usize {
unsafe { vec_znx_big::bytes_of_vec_znx_big(self.ptr, limbs as u64) as usize * cols }
VecZnxBig::bytes_of(self, cols, limbs)
}
/// [VecZnxBig] (3 cols and 4 limbs)