mirror of
https://github.com/arnaucube/poulpy.git
synced 2026-02-10 13:16:44 +01:00
more refactoring
This commit is contained in:
@@ -1,7 +1,7 @@
|
||||
use crate::ffi::vec_znx_big::vec_znx_big_t;
|
||||
use crate::ffi::vec_znx_dft;
|
||||
use crate::ffi::vec_znx_dft::{bytes_of_vec_znx_dft, vec_znx_dft_t};
|
||||
use crate::{Backend, FFT64, Infos, Module, VecZnxBig, VecZnxLayout, assert_alignement};
|
||||
use crate::{Backend, FFT64, Module, VecZnxBig, ZnxBase, ZnxInfos, ZnxLayout, assert_alignement};
|
||||
use crate::{VecZnx, alloc_aligned};
|
||||
use std::marker::PhantomData;
|
||||
|
||||
@@ -14,15 +14,17 @@ pub struct VecZnxDft<B: Backend> {
|
||||
pub _marker: PhantomData<B>,
|
||||
}
|
||||
|
||||
impl VecZnxDft<FFT64> {
|
||||
pub fn new(module: &Module<FFT64>, cols: usize, limbs: usize) -> Self {
|
||||
impl<B: Backend> ZnxBase<B> for VecZnxDft<B> {
|
||||
type Scalar = u8;
|
||||
|
||||
fn new(module: &Module<B>, cols: usize, limbs: usize) -> Self {
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
assert!(cols > 0);
|
||||
assert!(limbs > 0);
|
||||
}
|
||||
let mut data: Vec<u8> = alloc_aligned::<u8>(module.bytes_of_vec_znx_dft(cols, limbs));
|
||||
let ptr: *mut u8 = data.as_mut_ptr();
|
||||
let mut data: Vec<Self::Scalar> = alloc_aligned(Self::bytes_of(module, cols, limbs));
|
||||
let ptr: *mut Self::Scalar = data.as_mut_ptr();
|
||||
Self {
|
||||
data: data,
|
||||
ptr: ptr,
|
||||
@@ -33,19 +35,19 @@ impl VecZnxDft<FFT64> {
|
||||
}
|
||||
}
|
||||
|
||||
fn bytes_of(module: &Module<FFT64>, cols: usize, limbs: usize) -> usize {
|
||||
fn bytes_of(module: &Module<B>, cols: usize, limbs: usize) -> usize {
|
||||
unsafe { bytes_of_vec_znx_dft(module.ptr, limbs as u64) as usize * cols }
|
||||
}
|
||||
|
||||
/// Returns a new [VecZnxDft] with the provided data as backing array.
|
||||
/// User must ensure that data is properly alligned and that
|
||||
/// the size of data is at least equal to [Module::bytes_of_vec_znx_dft].
|
||||
pub fn from_bytes(module: &Module<FFT64>, cols: usize, limbs: usize, bytes: &mut [u8]) -> Self {
|
||||
fn from_bytes(module: &Module<B>, cols: usize, limbs: usize, bytes: &mut [Self::Scalar]) -> Self {
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
assert!(cols > 0);
|
||||
assert!(limbs > 0);
|
||||
assert_eq!(bytes.len(), module.bytes_of_vec_znx_dft(cols, limbs));
|
||||
assert_eq!(bytes.len(), Self::bytes_of(module, cols, limbs));
|
||||
assert_alignement(bytes.as_ptr())
|
||||
}
|
||||
unsafe {
|
||||
@@ -60,12 +62,12 @@ impl VecZnxDft<FFT64> {
|
||||
}
|
||||
}
|
||||
|
||||
pub fn from_bytes_borrow(module: &Module<FFT64>, cols: usize, limbs: usize, bytes: &mut [u8]) -> Self {
|
||||
fn from_bytes_borrow(module: &Module<B>, cols: usize, limbs: usize, bytes: &mut [Self::Scalar]) -> Self {
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
assert!(cols > 0);
|
||||
assert!(limbs > 0);
|
||||
assert_eq!(bytes.len(), module.bytes_of_vec_znx_dft(cols, limbs));
|
||||
assert_eq!(bytes.len(), Self::bytes_of(module, cols, limbs));
|
||||
assert_alignement(bytes.as_ptr());
|
||||
}
|
||||
Self {
|
||||
@@ -77,12 +79,14 @@ impl VecZnxDft<FFT64> {
|
||||
_marker: PhantomData,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<B: Backend> VecZnxDft<B> {
|
||||
/// Cast a [VecZnxDft] into a [VecZnxBig].
|
||||
/// The returned [VecZnxBig] shares the backing array
|
||||
/// with the original [VecZnxDft].
|
||||
pub fn as_vec_znx_big(&mut self) -> VecZnxBig<FFT64> {
|
||||
VecZnxBig::<FFT64> {
|
||||
pub fn as_vec_znx_big(&mut self) -> VecZnxBig<B> {
|
||||
VecZnxBig::<B> {
|
||||
data: Vec::new(),
|
||||
ptr: self.ptr,
|
||||
n: self.n,
|
||||
@@ -91,13 +95,9 @@ impl VecZnxDft<FFT64> {
|
||||
_marker: PhantomData,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn print(&self, n: usize) {
|
||||
(0..self.limbs()).for_each(|i| println!("{}: {:?}", i, &self.at_limb(i)[..n]));
|
||||
}
|
||||
}
|
||||
|
||||
impl<B: Backend> Infos for VecZnxDft<B> {
|
||||
impl<B: Backend> ZnxInfos for VecZnxDft<B> {
|
||||
fn n(&self) -> usize {
|
||||
self.n
|
||||
}
|
||||
@@ -123,7 +123,7 @@ impl<B: Backend> Infos for VecZnxDft<B> {
|
||||
}
|
||||
}
|
||||
|
||||
impl VecZnxLayout for VecZnxDft<FFT64> {
|
||||
impl ZnxLayout for VecZnxDft<FFT64> {
|
||||
type Scalar = f64;
|
||||
|
||||
fn as_ptr(&self) -> *const Self::Scalar {
|
||||
@@ -135,6 +135,12 @@ impl VecZnxLayout for VecZnxDft<FFT64> {
|
||||
}
|
||||
}
|
||||
|
||||
impl VecZnxDft<FFT64> {
|
||||
pub fn print(&self, n: usize) {
|
||||
(0..self.limbs()).for_each(|i| println!("{}: {:?}", i, &self.at_limb(i)[..n]));
|
||||
}
|
||||
}
|
||||
|
||||
pub trait VecZnxDftOps<B: Backend> {
|
||||
/// Allocates a vector Z[X]/(X^N+1) that stores normalized in the DFT space.
|
||||
fn new_vec_znx_dft(&self, cols: usize, limbs: usize) -> VecZnxDft<B>;
|
||||
@@ -314,7 +320,7 @@ impl VecZnxDftOps<FFT64> for Module<FFT64> {
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use crate::{FFT64, Module, Sampling, VecZnx, VecZnxDft, VecZnxDftOps, VecZnxLayout, VecZnxOps, alloc_aligned};
|
||||
use crate::{FFT64, Module, Sampling, VecZnx, VecZnxDft, VecZnxDftOps, VecZnxOps, ZnxLayout, alloc_aligned};
|
||||
use itertools::izip;
|
||||
use sampling::source::Source;
|
||||
|
||||
|
||||
Reference in New Issue
Block a user