mirror of
https://github.com/arnaucube/poulpy.git
synced 2026-02-10 13:16:44 +01:00
improved alligned vec allocation & fixed vec_znx calls, fixed auto dft test
This commit is contained in:
@@ -82,14 +82,18 @@ impl VecZnxDft<FFT64> {
|
||||
}
|
||||
}
|
||||
|
||||
pub fn raw(&self) -> &[f64] {
|
||||
let ptr: *mut f64 = self.ptr as *mut f64;
|
||||
let size: usize = self.n() * self.poly_count();
|
||||
unsafe { &std::slice::from_raw_parts(ptr, size) }
|
||||
/// Returns a non-mutable pointer to the backedn slice of the receiver.
|
||||
pub fn as_ptr(&self) -> *const f64 {
|
||||
self.ptr as *const f64
|
||||
}
|
||||
|
||||
pub fn at(&self, col_i: usize) -> &[f64] {
|
||||
&self.raw()[col_i * self.n() * self.limbs()..(col_i + 1) * self.n() * self.limbs()]
|
||||
/// Returns a mutable pointer to the backedn slice of the receiver.
|
||||
pub fn as_mut_ptr(&mut self) -> *mut f64 {
|
||||
self.ptr as *mut f64
|
||||
}
|
||||
|
||||
pub fn raw(&self) -> &[f64] {
|
||||
unsafe { &std::slice::from_raw_parts(self.as_ptr(), self.n() * self.poly_count()) }
|
||||
}
|
||||
|
||||
pub fn raw_mut(&mut self) -> &mut [f64] {
|
||||
@@ -98,10 +102,54 @@ impl VecZnxDft<FFT64> {
|
||||
unsafe { std::slice::from_raw_parts_mut(ptr, size) }
|
||||
}
|
||||
|
||||
pub fn at_mut(&mut self, col_i: usize) -> &mut [f64] {
|
||||
let n: usize = self.n();
|
||||
let limbs:usize = self.limbs();
|
||||
&mut self.raw_mut()[col_i * n * limbs..(col_i + 1) * n * limbs]
|
||||
pub fn at_ptr(&self, i: usize, j: usize) -> *const f64 {
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
assert!(i < self.cols());
|
||||
assert!(j < self.limbs());
|
||||
}
|
||||
let offset: usize = self.n * (j * self.cols() + i);
|
||||
self.as_ptr().wrapping_add(offset)
|
||||
}
|
||||
|
||||
/// Returns a non-mutable reference to the i-th limb.
|
||||
/// The returned array is of size [Self::n()] * [Self::cols()].
|
||||
pub fn at_limb(&self, i: usize) -> &[f64] {
|
||||
unsafe { std::slice::from_raw_parts(self.at_ptr(0, i), self.n * self.cols()) }
|
||||
}
|
||||
|
||||
/// Returns a non-mutable reference to the (i, j)-th poly.
|
||||
/// The returned array is of size [Self::n()].
|
||||
pub fn at_poly(&self, i: usize, j: usize) -> &[f64] {
|
||||
unsafe { std::slice::from_raw_parts(self.at_ptr(i, j), self.n) }
|
||||
}
|
||||
|
||||
/// Returns a mutable pointer starting a the (i, j)-th small poly.
|
||||
pub fn at_mut_ptr(&mut self, i: usize, j: usize) -> *mut f64 {
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
assert!(i < self.cols());
|
||||
assert!(j < self.limbs());
|
||||
}
|
||||
|
||||
let offset: usize = self.n * (j * self.cols() + i);
|
||||
self.as_mut_ptr().wrapping_add(offset)
|
||||
}
|
||||
|
||||
/// Returns a mutable reference to the i-th limb.
|
||||
/// The returned array is of size [Self::n()] * [Self::cols()].
|
||||
pub fn at_limb_mut(&mut self, i: usize) -> &mut [f64] {
|
||||
unsafe { std::slice::from_raw_parts_mut(self.at_mut_ptr(0, i), self.n * self.cols()) }
|
||||
}
|
||||
|
||||
/// Returns a mutable reference to the (i, j)-th poly.
|
||||
/// The returned array is of size [Self::n()].
|
||||
pub fn at_poly_mut(&mut self, i: usize, j: usize) -> &mut [f64] {
|
||||
unsafe { std::slice::from_raw_parts_mut(self.at_mut_ptr(i, j), self.n) }
|
||||
}
|
||||
|
||||
pub fn print(&self, n: usize) {
|
||||
(0..self.limbs()).for_each(|i| println!("{}: {:?}", i, &self.at_limb(i)[..n]));
|
||||
}
|
||||
}
|
||||
|
||||
@@ -289,6 +337,7 @@ impl VecZnxDftOps<FFT64> for Module<FFT64> {
|
||||
);
|
||||
assert_alignement(tmp_bytes.as_ptr())
|
||||
}
|
||||
println!("{}", a.poly_count());
|
||||
unsafe {
|
||||
vec_znx_dft::vec_znx_dft_automorphism(
|
||||
self.ptr,
|
||||
@@ -303,12 +352,7 @@ impl VecZnxDftOps<FFT64> for Module<FFT64> {
|
||||
}
|
||||
|
||||
fn vec_znx_dft_automorphism_tmp_bytes(&self) -> usize {
|
||||
unsafe {
|
||||
std::cmp::max(
|
||||
vec_znx_dft::vec_znx_dft_automorphism_tmp_bytes(self.ptr) as usize,
|
||||
DEFAULTALIGN,
|
||||
)
|
||||
}
|
||||
unsafe { vec_znx_dft::vec_znx_dft_automorphism_tmp_bytes(self.ptr) as usize }
|
||||
}
|
||||
}
|
||||
|
||||
@@ -316,11 +360,12 @@ impl VecZnxDftOps<FFT64> for Module<FFT64> {
|
||||
mod tests {
|
||||
use crate::{FFT64, Module, Sampling, VecZnx, VecZnxDft, VecZnxDftOps, VecZnxOps, alloc_aligned};
|
||||
use itertools::izip;
|
||||
use sampling::source::{Source, new_seed};
|
||||
use sampling::source::Source;
|
||||
|
||||
#[test]
|
||||
fn test_automorphism_dft() {
|
||||
let module: Module<FFT64> = Module::<FFT64>::new(128);
|
||||
let n: usize = 8;
|
||||
let module: Module<FFT64> = Module::<FFT64>::new(n);
|
||||
|
||||
let limbs: usize = 2;
|
||||
let log_base2k: usize = 17;
|
||||
@@ -328,25 +373,19 @@ mod tests {
|
||||
let mut a_dft: VecZnxDft<FFT64> = module.new_vec_znx_dft(1, limbs);
|
||||
let mut b_dft: VecZnxDft<FFT64> = module.new_vec_znx_dft(1, limbs);
|
||||
|
||||
let mut source: Source = Source::new(new_seed());
|
||||
let mut source: Source = Source::new([0u8; 32]);
|
||||
module.fill_uniform(log_base2k, &mut a, 0, limbs, &mut source);
|
||||
|
||||
let mut tmp_bytes: Vec<u8> = alloc_aligned(module.vec_znx_dft_automorphism_tmp_bytes());
|
||||
|
||||
let p: i64 = -5;
|
||||
|
||||
|
||||
|
||||
// a_dft <- DFT(a)
|
||||
module.vec_znx_dft(&mut a_dft, &a);
|
||||
|
||||
|
||||
|
||||
// a_dft <- AUTO(a_dft)
|
||||
module.vec_znx_dft_automorphism_inplace(p, &mut a_dft, &mut tmp_bytes);
|
||||
|
||||
println!("123");
|
||||
|
||||
// a <- AUTO(a)
|
||||
module.vec_znx_automorphism_inplace(p, &mut a);
|
||||
|
||||
|
||||
Reference in New Issue
Block a user