improved alligned vec allocation & fixed vec_znx calls, fixed auto dft test

This commit is contained in:
Jean-Philippe Bossuat
2025-04-26 11:23:47 +02:00
parent 2a96f89047
commit 82082db727
6 changed files with 113 additions and 67 deletions

View File

@@ -82,14 +82,18 @@ impl VecZnxDft<FFT64> {
}
}
pub fn raw(&self) -> &[f64] {
let ptr: *mut f64 = self.ptr as *mut f64;
let size: usize = self.n() * self.poly_count();
unsafe { &std::slice::from_raw_parts(ptr, size) }
/// Returns a non-mutable pointer to the backedn slice of the receiver.
pub fn as_ptr(&self) -> *const f64 {
self.ptr as *const f64
}
pub fn at(&self, col_i: usize) -> &[f64] {
&self.raw()[col_i * self.n() * self.limbs()..(col_i + 1) * self.n() * self.limbs()]
/// Returns a mutable pointer to the backedn slice of the receiver.
pub fn as_mut_ptr(&mut self) -> *mut f64 {
self.ptr as *mut f64
}
pub fn raw(&self) -> &[f64] {
unsafe { &std::slice::from_raw_parts(self.as_ptr(), self.n() * self.poly_count()) }
}
pub fn raw_mut(&mut self) -> &mut [f64] {
@@ -98,10 +102,54 @@ impl VecZnxDft<FFT64> {
unsafe { std::slice::from_raw_parts_mut(ptr, size) }
}
pub fn at_mut(&mut self, col_i: usize) -> &mut [f64] {
let n: usize = self.n();
let limbs:usize = self.limbs();
&mut self.raw_mut()[col_i * n * limbs..(col_i + 1) * n * limbs]
pub fn at_ptr(&self, i: usize, j: usize) -> *const f64 {
#[cfg(debug_assertions)]
{
assert!(i < self.cols());
assert!(j < self.limbs());
}
let offset: usize = self.n * (j * self.cols() + i);
self.as_ptr().wrapping_add(offset)
}
/// Returns a non-mutable reference to the i-th limb.
/// The returned array is of size [Self::n()] * [Self::cols()].
pub fn at_limb(&self, i: usize) -> &[f64] {
unsafe { std::slice::from_raw_parts(self.at_ptr(0, i), self.n * self.cols()) }
}
/// Returns a non-mutable reference to the (i, j)-th poly.
/// The returned array is of size [Self::n()].
pub fn at_poly(&self, i: usize, j: usize) -> &[f64] {
unsafe { std::slice::from_raw_parts(self.at_ptr(i, j), self.n) }
}
/// Returns a mutable pointer starting a the (i, j)-th small poly.
pub fn at_mut_ptr(&mut self, i: usize, j: usize) -> *mut f64 {
#[cfg(debug_assertions)]
{
assert!(i < self.cols());
assert!(j < self.limbs());
}
let offset: usize = self.n * (j * self.cols() + i);
self.as_mut_ptr().wrapping_add(offset)
}
/// Returns a mutable reference to the i-th limb.
/// The returned array is of size [Self::n()] * [Self::cols()].
pub fn at_limb_mut(&mut self, i: usize) -> &mut [f64] {
unsafe { std::slice::from_raw_parts_mut(self.at_mut_ptr(0, i), self.n * self.cols()) }
}
/// Returns a mutable reference to the (i, j)-th poly.
/// The returned array is of size [Self::n()].
pub fn at_poly_mut(&mut self, i: usize, j: usize) -> &mut [f64] {
unsafe { std::slice::from_raw_parts_mut(self.at_mut_ptr(i, j), self.n) }
}
pub fn print(&self, n: usize) {
(0..self.limbs()).for_each(|i| println!("{}: {:?}", i, &self.at_limb(i)[..n]));
}
}
@@ -289,6 +337,7 @@ impl VecZnxDftOps<FFT64> for Module<FFT64> {
);
assert_alignement(tmp_bytes.as_ptr())
}
println!("{}", a.poly_count());
unsafe {
vec_znx_dft::vec_znx_dft_automorphism(
self.ptr,
@@ -303,12 +352,7 @@ impl VecZnxDftOps<FFT64> for Module<FFT64> {
}
fn vec_znx_dft_automorphism_tmp_bytes(&self) -> usize {
unsafe {
std::cmp::max(
vec_znx_dft::vec_znx_dft_automorphism_tmp_bytes(self.ptr) as usize,
DEFAULTALIGN,
)
}
unsafe { vec_znx_dft::vec_znx_dft_automorphism_tmp_bytes(self.ptr) as usize }
}
}
@@ -316,11 +360,12 @@ impl VecZnxDftOps<FFT64> for Module<FFT64> {
mod tests {
use crate::{FFT64, Module, Sampling, VecZnx, VecZnxDft, VecZnxDftOps, VecZnxOps, alloc_aligned};
use itertools::izip;
use sampling::source::{Source, new_seed};
use sampling::source::Source;
#[test]
fn test_automorphism_dft() {
let module: Module<FFT64> = Module::<FFT64>::new(128);
let n: usize = 8;
let module: Module<FFT64> = Module::<FFT64>::new(n);
let limbs: usize = 2;
let log_base2k: usize = 17;
@@ -328,25 +373,19 @@ mod tests {
let mut a_dft: VecZnxDft<FFT64> = module.new_vec_znx_dft(1, limbs);
let mut b_dft: VecZnxDft<FFT64> = module.new_vec_znx_dft(1, limbs);
let mut source: Source = Source::new(new_seed());
let mut source: Source = Source::new([0u8; 32]);
module.fill_uniform(log_base2k, &mut a, 0, limbs, &mut source);
let mut tmp_bytes: Vec<u8> = alloc_aligned(module.vec_znx_dft_automorphism_tmp_bytes());
let p: i64 = -5;
// a_dft <- DFT(a)
module.vec_znx_dft(&mut a_dft, &a);
// a_dft <- AUTO(a_dft)
module.vec_znx_dft_automorphism_inplace(p, &mut a_dft, &mut tmp_bytes);
println!("123");
// a <- AUTO(a)
module.vec_znx_automorphism_inplace(p, &mut a);