improved aligned vec allocation & fixed vec_znx calls, fixed auto dft test

Jean-Philippe Bossuat
2025-04-26 11:23:47 +02:00
parent 2a96f89047
commit 82082db727
6 changed files with 113 additions and 67 deletions

View File

@@ -34,8 +34,6 @@ fn main() {
let mut a: VecZnx = module.new_vec_znx(1, limbs);
module.fill_uniform(log_base2k, &mut a, 0, limbs, &mut source);
// Scratch space for DFT values
let mut buf_dft: VecZnxDft<FFT64> = module.new_vec_znx_dft(1, a.limbs());
@@ -93,8 +91,6 @@ fn main() {
// res <- normalize(buf_big)
module.vec_znx_big_normalize(log_base2k, &mut res, &buf_big, &mut carry);
// have = m * 2^{log_scale} + e
let mut have: Vec<i64> = vec![i64::default(); n];
res.decode_vec_i64(0, log_base2k, res.limbs() * log_base2k, &mut have);

View File

@@ -106,6 +106,14 @@ pub fn alloc_aligned_custom<T>(size: usize, align: usize) -> Vec<T> {
unsafe { Vec::from_raw_parts(ptr, len, cap) }
}
+ // Allocates an aligned vector whose size is the smallest power of two greater than or equal to `size`
+ // and at least as big as DEFAULTALIGN / std::mem::size_of::<T>().
pub fn alloc_aligned<T>(size: usize) -> Vec<T> {
- alloc_aligned_custom::<T>(size, DEFAULTALIGN)
+ alloc_aligned_custom::<T>(
+     std::cmp::max(
+         size.next_power_of_two(),
+         DEFAULTALIGN / std::mem::size_of::<T>(),
+     ),
+     DEFAULTALIGN,
+ )
}
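For readers checking the new sizing rule: a minimal standalone sketch of the same arithmetic, outside the diff. The DEFAULTALIGN value of 64 bytes is an assumption for illustration only; the crate defines its own constant.

// Standalone sketch of the padded-length rule used by alloc_aligned above.
// DEFAULTALIGN = 64 is an assumed value for illustration; the crate defines its own constant.
const DEFAULTALIGN: usize = 64;

fn padded_len<T>(size: usize) -> usize {
    std::cmp::max(size.next_power_of_two(), DEFAULTALIGN / std::mem::size_of::<T>())
}

fn main() {
    // 6 i64 elements: next_power_of_two(6) = 8 and 64 / 8 = 8, so 8 elements are allocated.
    assert_eq!(padded_len::<i64>(6), 8);
    // 3 u8 elements: next_power_of_two(3) = 4, but 64 / 1 = 64 dominates.
    assert_eq!(padded_len::<u8>(3), 64);
}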

View File

@@ -478,13 +478,13 @@ impl<B: Backend> VecZnxOps for Module<B> {
self.ptr,
c.as_mut_ptr(),
c.limbs() as u64,
- (n * c.limbs()) as u64,
+ (n * c.cols()) as u64,
a.as_ptr(),
a.limbs() as u64,
- (n * a.limbs()) as u64,
+ (n * a.cols()) as u64,
b.as_ptr(),
b.limbs() as u64,
- (n * b.limbs()) as u64,
+ (n * b.cols()) as u64,
)
}
}
@@ -502,13 +502,13 @@ impl<B: Backend> VecZnxOps for Module<B> {
self.ptr,
b.as_mut_ptr(),
b.limbs() as u64,
- (n * b.limbs()) as u64,
+ (n * b.cols()) as u64,
a.as_ptr(),
a.limbs() as u64,
- (n * a.limbs()) as u64,
+ (n * a.cols()) as u64,
b.as_ptr(),
b.limbs() as u64,
- (n * b.limbs()) as u64,
+ (n * b.cols()) as u64,
)
}
}
@@ -527,13 +527,13 @@ impl<B: Backend> VecZnxOps for Module<B> {
self.ptr,
c.as_mut_ptr(),
c.limbs() as u64,
- (n * c.limbs()) as u64,
+ (n * c.cols()) as u64,
a.as_ptr(),
a.limbs() as u64,
- (n * a.limbs()) as u64,
+ (n * a.cols()) as u64,
b.as_ptr(),
b.limbs() as u64,
- (n * b.limbs()) as u64,
+ (n * b.cols()) as u64,
)
}
}
@@ -551,13 +551,13 @@ impl<B: Backend> VecZnxOps for Module<B> {
self.ptr,
b.as_mut_ptr(),
b.limbs() as u64,
- (n * b.limbs()) as u64,
+ (n * b.cols()) as u64,
a.as_ptr(),
a.limbs() as u64,
- (n * a.limbs()) as u64,
+ (n * a.cols()) as u64,
b.as_ptr(),
b.limbs() as u64,
- (n * b.limbs()) as u64,
+ (n * b.cols()) as u64,
)
}
}
@@ -575,13 +575,13 @@ impl<B: Backend> VecZnxOps for Module<B> {
self.ptr,
b.as_mut_ptr(),
b.limbs() as u64,
- (n * b.limbs()) as u64,
+ (n * b.cols()) as u64,
b.as_ptr(),
b.limbs() as u64,
- (n * b.limbs()) as u64,
+ (n * b.cols()) as u64,
a.as_ptr(),
a.limbs() as u64,
- (n * a.limbs()) as u64,
+ (n * a.cols()) as u64,
)
}
}
@@ -598,10 +598,10 @@ impl<B: Backend> VecZnxOps for Module<B> {
self.ptr,
b.as_mut_ptr(),
b.limbs() as u64,
- (n * b.limbs()) as u64,
+ (n * b.cols()) as u64,
a.as_ptr(),
a.limbs() as u64,
- (n * a.limbs()) as u64,
+ (n * a.cols()) as u64,
)
}
}
@@ -617,10 +617,10 @@ impl<B: Backend> VecZnxOps for Module<B> {
self.ptr,
a.as_mut_ptr(),
a.limbs() as u64,
- (n * a.limbs()) as u64,
+ (n * a.cols()) as u64,
a.as_ptr(),
a.limbs() as u64,
- (n * a.limbs()) as u64,
+ (n * a.cols()) as u64,
)
}
}
@@ -638,10 +638,10 @@ impl<B: Backend> VecZnxOps for Module<B> {
k,
b.as_mut_ptr(),
b.limbs() as u64,
- (n * b.limbs()) as u64,
+ (n * b.cols()) as u64,
a.as_ptr(),
a.limbs() as u64,
- (n * a.limbs()) as u64,
+ (n * a.cols()) as u64,
)
}
}
@@ -658,10 +658,10 @@ impl<B: Backend> VecZnxOps for Module<B> {
k,
a.as_mut_ptr(),
a.limbs() as u64,
- (n * a.limbs()) as u64,
+ (n * a.cols()) as u64,
a.as_ptr(),
a.limbs() as u64,
- (n * a.limbs()) as u64,
+ (n * a.cols()) as u64,
)
}
}
@@ -691,10 +691,10 @@ impl<B: Backend> VecZnxOps for Module<B> {
k,
b.as_mut_ptr(),
b.limbs() as u64,
- (n * b.limbs()) as u64,
+ (n * b.cols()) as u64,
a.as_ptr(),
a.limbs() as u64,
- (n * a.limbs()) as u64,
+ (n * a.cols()) as u64,
);
}
}
@@ -722,10 +722,10 @@ impl<B: Backend> VecZnxOps for Module<B> {
k,
a.as_mut_ptr(),
a.limbs() as u64,
- (n * a.limbs()) as u64,
+ (n * a.cols()) as u64,
a.as_ptr(),
a.limbs() as u64,
- (n * a.limbs()) as u64,
+ (n * a.cols()) as u64,
);
}
}
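All of the call-site fixes above change the slice-stride argument from n * limbs to n * cols: one limb holds `cols` polynomials of `n` coefficients, so stepping between limbs must advance by n * cols. A small standalone sketch of that indexing (plain slices, not the crate's API):

// Illustrative layout check: each limb stores `cols` polynomials of `n` coefficients
// back to back, so the per-limb stride handed to the backend is n * cols, not n * limbs.
fn limb_offset(n: usize, cols: usize, limb: usize) -> usize {
    limb * n * cols
}

fn main() {
    let (n, cols, limbs) = (8usize, 2usize, 3usize);
    let data = vec![0i64; n * cols * limbs];
    for j in 0..limbs {
        let start = limb_offset(n, cols, j);
        let limb_slice = &data[start..start + n * cols];
        assert_eq!(limb_slice.len(), n * cols);
    }
}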

View File

@@ -73,7 +73,13 @@ impl VecZnxBig<FFT64> {
// Prints the first `n` coefficients of each limb
pub fn print(&self, n: usize) {
let raw: &[i64] = self.raw();
- (0..self.limbs()).for_each(|i| println!("{}: {:?}", i, &raw[i * self.n() * self.cols()..i * self.n() * self.cols()+n]))
+ (0..self.limbs()).for_each(|i| {
+     println!(
+         "{}: {:?}",
+         i,
+         &raw[i * self.n() * self.cols()..i * self.n() * self.cols() + n]
+     )
+ })
}
}

View File

@@ -82,14 +82,18 @@ impl VecZnxDft<FFT64> {
}
}
- pub fn raw(&self) -> &[f64] {
-     let ptr: *mut f64 = self.ptr as *mut f64;
-     let size: usize = self.n() * self.poly_count();
-     unsafe { &std::slice::from_raw_parts(ptr, size) }
- }
- pub fn at(&self, col_i: usize) -> &[f64] {
-     &self.raw()[col_i * self.n() * self.limbs()..(col_i + 1) * self.n() * self.limbs()]
- }
+ /// Returns a non-mutable pointer to the backend slice of the receiver.
+ pub fn as_ptr(&self) -> *const f64 {
+     self.ptr as *const f64
+ }
+ /// Returns a mutable pointer to the backend slice of the receiver.
+ pub fn as_mut_ptr(&mut self) -> *mut f64 {
+     self.ptr as *mut f64
+ }
+ pub fn raw(&self) -> &[f64] {
+     unsafe { &std::slice::from_raw_parts(self.as_ptr(), self.n() * self.poly_count()) }
+ }
pub fn raw_mut(&mut self) -> &mut [f64] {
@@ -98,10 +102,54 @@ impl VecZnxDft<FFT64> {
unsafe { std::slice::from_raw_parts_mut(ptr, size) }
}
- pub fn at_mut(&mut self, col_i: usize) -> &mut [f64] {
-     let n: usize = self.n();
-     let limbs: usize = self.limbs();
-     &mut self.raw_mut()[col_i * n * limbs..(col_i + 1) * n * limbs]
- }
+ pub fn at_ptr(&self, i: usize, j: usize) -> *const f64 {
+     #[cfg(debug_assertions)]
+     {
+         assert!(i < self.cols());
+         assert!(j < self.limbs());
+     }
+     let offset: usize = self.n * (j * self.cols() + i);
+     self.as_ptr().wrapping_add(offset)
+ }
+ /// Returns a non-mutable reference to the i-th limb.
+ /// The returned array is of size [Self::n()] * [Self::cols()].
+ pub fn at_limb(&self, i: usize) -> &[f64] {
+     unsafe { std::slice::from_raw_parts(self.at_ptr(0, i), self.n * self.cols()) }
+ }
+ /// Returns a non-mutable reference to the (i, j)-th poly.
+ /// The returned array is of size [Self::n()].
+ pub fn at_poly(&self, i: usize, j: usize) -> &[f64] {
+     unsafe { std::slice::from_raw_parts(self.at_ptr(i, j), self.n) }
+ }
+ /// Returns a mutable pointer starting at the (i, j)-th small poly.
+ pub fn at_mut_ptr(&mut self, i: usize, j: usize) -> *mut f64 {
+     #[cfg(debug_assertions)]
+     {
+         assert!(i < self.cols());
+         assert!(j < self.limbs());
+     }
+     let offset: usize = self.n * (j * self.cols() + i);
+     self.as_mut_ptr().wrapping_add(offset)
+ }
+ /// Returns a mutable reference to the i-th limb.
+ /// The returned array is of size [Self::n()] * [Self::cols()].
+ pub fn at_limb_mut(&mut self, i: usize) -> &mut [f64] {
+     unsafe { std::slice::from_raw_parts_mut(self.at_mut_ptr(0, i), self.n * self.cols()) }
+ }
+ /// Returns a mutable reference to the (i, j)-th poly.
+ /// The returned array is of size [Self::n()].
+ pub fn at_poly_mut(&mut self, i: usize, j: usize) -> &mut [f64] {
+     unsafe { std::slice::from_raw_parts_mut(self.at_mut_ptr(i, j), self.n) }
+ }
+ pub fn print(&self, n: usize) {
+     (0..self.limbs()).for_each(|i| println!("{}: {:?}", i, &self.at_limb(i)[..n]));
+ }
}
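The offset used by at_ptr / at_mut_ptr pins down the layout convention: limbs are outermost, columns next, each poly spanning n coefficients, i.e. offset = n * (j * cols + i). A standalone sketch of the same addressing on a plain slice, assuming nothing beyond what the diff shows:

// Standalone mirror of the VecZnxDft addressing above (plain slices, not the crate's types):
// poly (i, j) starts at n * (j * cols + i); limb j spans n * cols consecutive values.
fn poly_offset(n: usize, cols: usize, i: usize, j: usize) -> usize {
    n * (j * cols + i)
}

fn main() {
    let (n, cols, limbs) = (4usize, 3usize, 2usize);
    let data: Vec<f64> = (0..n * cols * limbs).map(|x| x as f64).collect();
    let (i, j) = (1usize, 1usize);
    // Slice out the (i, j)-th poly and the j-th limb the same way at_poly / at_limb do.
    let poly = &data[poly_offset(n, cols, i, j)..poly_offset(n, cols, i, j) + n];
    let limb = &data[poly_offset(n, cols, 0, j)..poly_offset(n, cols, 0, j) + n * cols];
    assert_eq!(poly.len(), n);
    assert_eq!(limb.len(), n * cols);
}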
@@ -289,6 +337,7 @@ impl VecZnxDftOps<FFT64> for Module<FFT64> {
);
assert_alignement(tmp_bytes.as_ptr())
}
+ println!("{}", a.poly_count());
unsafe {
vec_znx_dft::vec_znx_dft_automorphism(
self.ptr,
@@ -303,12 +352,7 @@ impl VecZnxDftOps<FFT64> for Module<FFT64> {
}
fn vec_znx_dft_automorphism_tmp_bytes(&self) -> usize {
- unsafe {
-     std::cmp::max(
-         vec_znx_dft::vec_znx_dft_automorphism_tmp_bytes(self.ptr) as usize,
-         DEFAULTALIGN,
-     )
- }
+ unsafe { vec_znx_dft::vec_znx_dft_automorphism_tmp_bytes(self.ptr) as usize }
}
}
@@ -316,11 +360,12 @@ impl VecZnxDftOps<FFT64> for Module<FFT64> {
mod tests {
use crate::{FFT64, Module, Sampling, VecZnx, VecZnxDft, VecZnxDftOps, VecZnxOps, alloc_aligned};
use itertools::izip;
- use sampling::source::{Source, new_seed};
+ use sampling::source::Source;
#[test]
fn test_automorphism_dft() {
- let module: Module<FFT64> = Module::<FFT64>::new(128);
+ let n: usize = 8;
+ let module: Module<FFT64> = Module::<FFT64>::new(n);
let limbs: usize = 2;
let log_base2k: usize = 17;
@@ -328,25 +373,19 @@ mod tests {
let mut a_dft: VecZnxDft<FFT64> = module.new_vec_znx_dft(1, limbs);
let mut b_dft: VecZnxDft<FFT64> = module.new_vec_znx_dft(1, limbs);
- let mut source: Source = Source::new(new_seed());
+ let mut source: Source = Source::new([0u8; 32]);
module.fill_uniform(log_base2k, &mut a, 0, limbs, &mut source);
let mut tmp_bytes: Vec<u8> = alloc_aligned(module.vec_znx_dft_automorphism_tmp_bytes());
let p: i64 = -5;
// a_dft <- DFT(a)
module.vec_znx_dft(&mut a_dft, &a);
// a_dft <- AUTO(a_dft)
module.vec_znx_dft_automorphism_inplace(p, &mut a_dft, &mut tmp_bytes);
- println!("123");
// a <- AUTO(a)
module.vec_znx_automorphism_inplace(p, &mut a);

View File

@@ -53,7 +53,6 @@ impl<B: Backend> Infos for VmpPMat<B> {
}
impl VmpPMat<FFT64> {
fn new(module: &Module<FFT64>, rows: usize, cols: usize, limbs: usize) -> VmpPMat<FFT64> {
let mut data: Vec<u8> = alloc_aligned::<u8>(module.bytes_of_vmp_pmat(rows, cols, limbs));
let ptr: *mut u8 = data.as_mut_ptr();
@@ -352,21 +351,19 @@ pub trait VmpPMatOps<B: Backend> {
}
impl VmpPMatOps<FFT64> for Module<FFT64> {
fn new_vmp_pmat(&self, rows: usize, cols: usize, limbs: usize) -> VmpPMat<FFT64> {
VmpPMat::<FFT64>::new(self, rows, cols, limbs)
}
fn bytes_of_vmp_pmat(&self, rows: usize, cols: usize, limbs: usize) -> usize {
- unsafe { vmp::bytes_of_vmp_pmat(self.ptr, rows as u64, (limbs* cols) as u64) as usize }
+ unsafe { vmp::bytes_of_vmp_pmat(self.ptr, rows as u64, (limbs * cols) as u64) as usize }
}
fn vmp_prepare_tmp_bytes(&self, rows: usize, cols: usize, size: usize) -> usize {
unsafe { vmp::vmp_prepare_tmp_bytes(self.ptr, rows as u64, (size * cols) as u64) as usize }
}
fn vmp_prepare_contiguous(&self, b: &mut VmpPMat<FFT64>, a: &[i64], tmp_bytes: &mut [u8]) {
#[cfg(debug_assertions)]
{
assert_eq!(a.len(), b.n() * b.poly_count());
@@ -379,7 +376,7 @@ impl VmpPMatOps<FFT64> for Module<FFT64> {
b.as_mut_ptr() as *mut vmp_pmat_t,
a.as_ptr(),
b.rows() as u64,
- (b.limbs()*b.cols()) as u64,
+ (b.limbs() * b.cols()) as u64,
tmp_bytes.as_mut_ptr(),
);
}
@@ -399,7 +396,7 @@ impl VmpPMatOps<FFT64> for Module<FFT64> {
a.as_ptr(),
row_i as u64,
b.rows() as u64,
- (b.limbs()*b.cols()) as u64,
+ (b.limbs() * b.cols()) as u64,
tmp_bytes.as_mut_ptr(),
);
}
@@ -419,7 +416,7 @@ impl VmpPMatOps<FFT64> for Module<FFT64> {
a.as_ptr() as *const vmp_pmat_t,
row_i as u64,
a.rows() as u64,
- (a.limbs()*a.cols()) as u64,
+ (a.limbs() * a.cols()) as u64,
);
}
}