diff --git a/base2k/examples/rlwe_encrypt.rs b/base2k/examples/rlwe_encrypt.rs
index 1da44e9..5385a5b 100644
--- a/base2k/examples/rlwe_encrypt.rs
+++ b/base2k/examples/rlwe_encrypt.rs
@@ -34,8 +34,6 @@ fn main() {
 
     let mut a: VecZnx = module.new_vec_znx(1, limbs);
     module.fill_uniform(log_base2k, &mut a, 0, limbs, &mut source);
-
-
     // Scratch space for DFT values
     let mut buf_dft: VecZnxDft = module.new_vec_znx_dft(1, a.limbs());
 
@@ -93,8 +91,6 @@ fn main() {
 
     // res <- normalize(buf_big)
     module.vec_znx_big_normalize(log_base2k, &mut res, &buf_big, &mut carry);
-
-
     // have = m * 2^{log_scale} + e
     let mut have: Vec<i64> = vec![i64::default(); n];
     res.decode_vec_i64(0, log_base2k, res.limbs() * log_base2k, &mut have);
diff --git a/base2k/src/lib.rs b/base2k/src/lib.rs
index 5144afd..83c937a 100644
--- a/base2k/src/lib.rs
+++ b/base2k/src/lib.rs
@@ -106,6 +106,14 @@ pub fn alloc_aligned_custom<T>(size: usize, align: usize) -> Vec<T> {
     unsafe { Vec::from_raw_parts(ptr, len, cap) }
 }
 
+// Allocates an aligned vector of size equal to the smallest power of two greater than or equal to `size` that is
+// at least as big as DEFAULTALIGN / std::mem::size_of::<T>().
 pub fn alloc_aligned<T>(size: usize) -> Vec<T> {
-    alloc_aligned_custom::<T>(size, DEFAULTALIGN)
+    alloc_aligned_custom::<T>(
+        std::cmp::max(
+            size.next_power_of_two(),
+            DEFAULTALIGN / std::mem::size_of::<T>(),
+        ),
+        DEFAULTALIGN,
+    )
 }
diff --git a/base2k/src/vec_znx.rs b/base2k/src/vec_znx.rs
index aff1ce9..a6d5858 100644
--- a/base2k/src/vec_znx.rs
+++ b/base2k/src/vec_znx.rs
@@ -478,13 +478,13 @@ impl VecZnxOps for Module {
                 self.ptr,
                 c.as_mut_ptr(),
                 c.limbs() as u64,
-                (n * c.limbs()) as u64,
+                (n * c.cols()) as u64,
                 a.as_ptr(),
                 a.limbs() as u64,
-                (n * a.limbs()) as u64,
+                (n * a.cols()) as u64,
                 b.as_ptr(),
                 b.limbs() as u64,
-                (n * b.limbs()) as u64,
+                (n * b.cols()) as u64,
             )
         }
     }
@@ -502,13 +502,13 @@ impl VecZnxOps for Module {
                 self.ptr,
                 b.as_mut_ptr(),
                 b.limbs() as u64,
-                (n * b.limbs()) as u64,
+                (n * b.cols()) as u64,
                 a.as_ptr(),
                 a.limbs() as u64,
-                (n * a.limbs()) as u64,
+                (n * a.cols()) as u64,
                 b.as_ptr(),
                 b.limbs() as u64,
-                (n * b.limbs()) as u64,
+                (n * b.cols()) as u64,
             )
         }
     }
@@ -527,13 +527,13 @@ impl VecZnxOps for Module {
                 self.ptr,
                 c.as_mut_ptr(),
                 c.limbs() as u64,
-                (n * c.limbs()) as u64,
+                (n * c.cols()) as u64,
                 a.as_ptr(),
                 a.limbs() as u64,
-                (n * a.limbs()) as u64,
+                (n * a.cols()) as u64,
                 b.as_ptr(),
                 b.limbs() as u64,
-                (n * b.limbs()) as u64,
+                (n * b.cols()) as u64,
             )
         }
     }
@@ -551,13 +551,13 @@ impl VecZnxOps for Module {
                 self.ptr,
                 b.as_mut_ptr(),
                 b.limbs() as u64,
-                (n * b.limbs()) as u64,
+                (n * b.cols()) as u64,
                 a.as_ptr(),
                 a.limbs() as u64,
-                (n * a.limbs()) as u64,
+                (n * a.cols()) as u64,
                 b.as_ptr(),
                 b.limbs() as u64,
-                (n * b.limbs()) as u64,
+                (n * b.cols()) as u64,
             )
         }
     }
@@ -575,13 +575,13 @@ impl VecZnxOps for Module {
                 self.ptr,
                 b.as_mut_ptr(),
                 b.limbs() as u64,
-                (n * b.limbs()) as u64,
+                (n * b.cols()) as u64,
                 b.as_ptr(),
                 b.limbs() as u64,
-                (n * b.limbs()) as u64,
+                (n * b.cols()) as u64,
                 a.as_ptr(),
                 a.limbs() as u64,
-                (n * a.limbs()) as u64,
+                (n * a.cols()) as u64,
             )
         }
     }
@@ -598,10 +598,10 @@ impl VecZnxOps for Module {
                 self.ptr,
                 b.as_mut_ptr(),
                 b.limbs() as u64,
-                (n * b.limbs()) as u64,
+                (n * b.cols()) as u64,
                 a.as_ptr(),
                 a.limbs() as u64,
-                (n * a.limbs()) as u64,
+                (n * a.cols()) as u64,
             )
         }
     }
@@ -617,10 +617,10 @@ impl VecZnxOps for Module {
                 self.ptr,
                 a.as_mut_ptr(),
                 a.limbs() as u64,
-                (n * a.limbs()) as u64,
+                (n * a.cols()) as u64,
                 a.as_ptr(),
                 a.limbs() as u64,
-                (n * a.limbs()) as u64,
+                (n * a.cols()) as u64,
             )
         }
     }
@@ -638,10 +638,10 @@ impl VecZnxOps for Module {
                 k,
                 b.as_mut_ptr(),
                 b.limbs() as u64,
-                (n * b.limbs()) as u64,
+                (n * b.cols()) as u64,
                 a.as_ptr(),
                 a.limbs() as u64,
-                (n * a.limbs()) as u64,
+                (n * a.cols()) as u64,
             )
         }
     }
@@ -658,10 +658,10 @@ impl VecZnxOps for Module {
                 k,
                 a.as_mut_ptr(),
                 a.limbs() as u64,
-                (n * a.limbs()) as u64,
+                (n * a.cols()) as u64,
                 a.as_ptr(),
                 a.limbs() as u64,
-                (n * a.limbs()) as u64,
+                (n * a.cols()) as u64,
             )
         }
     }
@@ -691,10 +691,10 @@ impl VecZnxOps for Module {
                 k,
                 b.as_mut_ptr(),
                 b.limbs() as u64,
-                (n * b.limbs()) as u64,
+                (n * b.cols()) as u64,
                 a.as_ptr(),
                 a.limbs() as u64,
-                (n * a.limbs()) as u64,
+                (n * a.cols()) as u64,
             );
         }
     }
@@ -722,10 +722,10 @@ impl VecZnxOps for Module {
                 k,
                 a.as_mut_ptr(),
                 a.limbs() as u64,
-                (n * a.limbs()) as u64,
+                (n * a.cols()) as u64,
                 a.as_ptr(),
                 a.limbs() as u64,
-                (n * a.limbs()) as u64,
+                (n * a.cols()) as u64,
             );
         }
     }
diff --git a/base2k/src/vec_znx_big.rs b/base2k/src/vec_znx_big.rs
index b19f126..a7bdd59 100644
--- a/base2k/src/vec_znx_big.rs
+++ b/base2k/src/vec_znx_big.rs
@@ -73,7 +73,13 @@ impl VecZnxBig {
 
     // Prints the first `n` coefficients of each limb
     pub fn print(&self, n: usize) {
         let raw: &[i64] = self.raw();
-        (0..self.limbs()).for_each(|i| println!("{}: {:?}", i, &raw[i * self.n() * self.cols()..i * self.n() * self.cols()+n]))
+        (0..self.limbs()).for_each(|i| {
+            println!(
+                "{}: {:?}",
+                i,
+                &raw[i * self.n() * self.cols()..i * self.n() * self.cols() + n]
+            )
+        })
     }
 }
diff --git a/base2k/src/vec_znx_dft.rs b/base2k/src/vec_znx_dft.rs
index ec4067f..61c2a85 100644
--- a/base2k/src/vec_znx_dft.rs
+++ b/base2k/src/vec_znx_dft.rs
@@ -82,14 +82,18 @@ impl VecZnxDft {
         }
     }
 
-    pub fn raw(&self) -> &[f64] {
-        let ptr: *mut f64 = self.ptr as *mut f64;
-        let size: usize = self.n() * self.poly_count();
-        unsafe { &std::slice::from_raw_parts(ptr, size) }
+    /// Returns a non-mutable pointer to the backing slice of the receiver.
+    pub fn as_ptr(&self) -> *const f64 {
+        self.ptr as *const f64
     }
 
-    pub fn at(&self, col_i: usize) -> &[f64] {
-        &self.raw()[col_i * self.n() * self.limbs()..(col_i + 1) * self.n() * self.limbs()]
+    /// Returns a mutable pointer to the backing slice of the receiver.
+    pub fn as_mut_ptr(&mut self) -> *mut f64 {
+        self.ptr as *mut f64
+    }
+
+    pub fn raw(&self) -> &[f64] {
+        unsafe { &std::slice::from_raw_parts(self.as_ptr(), self.n() * self.poly_count()) }
     }
 
     pub fn raw_mut(&mut self) -> &mut [f64] {
@@ -98,10 +102,54 @@ impl VecZnxDft {
         unsafe { std::slice::from_raw_parts_mut(ptr, size) }
     }
 
-    pub fn at_mut(&mut self, col_i: usize) -> &mut [f64] {
-        let n: usize = self.n();
-        let limbs:usize = self.limbs();
-        &mut self.raw_mut()[col_i * n * limbs..(col_i + 1) * n * limbs]
+    pub fn at_ptr(&self, i: usize, j: usize) -> *const f64 {
+        #[cfg(debug_assertions)]
+        {
+            assert!(i < self.cols());
+            assert!(j < self.limbs());
+        }
+        let offset: usize = self.n * (j * self.cols() + i);
+        self.as_ptr().wrapping_add(offset)
+    }
+
+    /// Returns a non-mutable reference to the i-th limb.
+    /// The returned array is of size [Self::n()] * [Self::cols()].
+    pub fn at_limb(&self, i: usize) -> &[f64] {
+        unsafe { std::slice::from_raw_parts(self.at_ptr(0, i), self.n * self.cols()) }
+    }
+
+    /// Returns a non-mutable reference to the (i, j)-th poly.
+    /// The returned array is of size [Self::n()].
+    pub fn at_poly(&self, i: usize, j: usize) -> &[f64] {
+        unsafe { std::slice::from_raw_parts(self.at_ptr(i, j), self.n) }
+    }
+
+    /// Returns a mutable pointer starting at the (i, j)-th small poly.
+    pub fn at_mut_ptr(&mut self, i: usize, j: usize) -> *mut f64 {
+        #[cfg(debug_assertions)]
+        {
+            assert!(i < self.cols());
+            assert!(j < self.limbs());
+        }
+
+        let offset: usize = self.n * (j * self.cols() + i);
+        self.as_mut_ptr().wrapping_add(offset)
+    }
+
+    /// Returns a mutable reference to the i-th limb.
+    /// The returned array is of size [Self::n()] * [Self::cols()].
+    pub fn at_limb_mut(&mut self, i: usize) -> &mut [f64] {
+        unsafe { std::slice::from_raw_parts_mut(self.at_mut_ptr(0, i), self.n * self.cols()) }
+    }
+
+    /// Returns a mutable reference to the (i, j)-th poly.
+    /// The returned array is of size [Self::n()].
+    pub fn at_poly_mut(&mut self, i: usize, j: usize) -> &mut [f64] {
+        unsafe { std::slice::from_raw_parts_mut(self.at_mut_ptr(i, j), self.n) }
+    }
+
+    pub fn print(&self, n: usize) {
+        (0..self.limbs()).for_each(|i| println!("{}: {:?}", i, &self.at_limb(i)[..n]));
     }
 }
@@ -289,6 +337,7 @@ impl VecZnxDftOps for Module {
             );
             assert_alignement(tmp_bytes.as_ptr())
         }
+        println!("{}", a.poly_count());
         unsafe {
             vec_znx_dft::vec_znx_dft_automorphism(
                 self.ptr,
@@ -303,12 +352,7 @@
     }
 
     fn vec_znx_dft_automorphism_tmp_bytes(&self) -> usize {
-        unsafe {
-            std::cmp::max(
-                vec_znx_dft::vec_znx_dft_automorphism_tmp_bytes(self.ptr) as usize,
-                DEFAULTALIGN,
-            )
-        }
+        unsafe { vec_znx_dft::vec_znx_dft_automorphism_tmp_bytes(self.ptr) as usize }
     }
 }
 
@@ -316,11 +360,12 @@ mod tests {
 
     use crate::{FFT64, Module, Sampling, VecZnx, VecZnxDft, VecZnxDftOps, VecZnxOps, alloc_aligned};
     use itertools::izip;
-    use sampling::source::{Source, new_seed};
+    use sampling::source::Source;
 
     #[test]
    fn test_automorphism_dft() {
-        let module: Module<FFT64> = Module::<FFT64>::new(128);
+        let n: usize = 8;
+        let module: Module<FFT64> = Module::<FFT64>::new(n);
         let limbs: usize = 2;
         let log_base2k: usize = 17;
 
@@ -328,25 +373,19 @@ mod tests {
         let mut a: VecZnx = module.new_vec_znx(1, limbs);
         let mut a_dft: VecZnxDft = module.new_vec_znx_dft(1, limbs);
         let mut b_dft: VecZnxDft = module.new_vec_znx_dft(1, limbs);
 
-        let mut source: Source = Source::new(new_seed());
+        let mut source: Source = Source::new([0u8; 32]);
 
         module.fill_uniform(log_base2k, &mut a, 0, limbs, &mut source);
 
         let mut tmp_bytes: Vec<u8> = alloc_aligned(module.vec_znx_dft_automorphism_tmp_bytes());
 
         let p: i64 = -5;
 
-
-
         // a_dft <- DFT(a)
         module.vec_znx_dft(&mut a_dft, &a);
 
-
-
         // a_dft <- AUTO(a_dft)
         module.vec_znx_dft_automorphism_inplace(p, &mut a_dft, &mut tmp_bytes);
-        println!("123");
-
         // a <- AUTO(a)
         module.vec_znx_automorphism_inplace(p, &mut a);
diff --git a/base2k/src/vmp.rs b/base2k/src/vmp.rs
index 05dd027..f868a06 100644
--- a/base2k/src/vmp.rs
+++ b/base2k/src/vmp.rs
@@ -53,7 +53,6 @@ impl Infos for VmpPMat {
 }
 
 impl VmpPMat {
-
     fn new(module: &Module, rows: usize, cols: usize, limbs: usize) -> VmpPMat {
         let mut data: Vec<u8> = alloc_aligned::<u8>(module.bytes_of_vmp_pmat(rows, cols, limbs));
         let ptr: *mut u8 = data.as_mut_ptr();
@@ -352,21 +351,19 @@ pub trait VmpPMatOps {
 }
 
 impl VmpPMatOps for Module {
-
     fn new_vmp_pmat(&self, rows: usize, cols: usize, limbs: usize) -> VmpPMat {
         VmpPMat::new(self, rows, cols, limbs)
     }
 
     fn bytes_of_vmp_pmat(&self, rows: usize, cols: usize, limbs: usize) -> usize {
-        unsafe { vmp::bytes_of_vmp_pmat(self.ptr, rows as u64, (limbs* cols) as u64) as usize }
+        unsafe { vmp::bytes_of_vmp_pmat(self.ptr, rows as u64, (limbs * cols) as u64) as usize }
     }
 
     fn vmp_prepare_tmp_bytes(&self, rows: usize, cols: usize, size: usize) -> usize {
-        unsafe { vmp::vmp_prepare_tmp_bytes(self.ptr, rows as u64, (size * cols) as u64) as usize }
+        unsafe { vmp::vmp_prepare_tmp_bytes(self.ptr, rows as u64, (size * cols) as u64) as usize }
     }
 
     fn vmp_prepare_contiguous(&self, b: &mut VmpPMat, a: &[i64], tmp_bytes: &mut [u8]) {
-
         #[cfg(debug_assertions)]
         {
             assert_eq!(a.len(), b.n() * b.poly_count());
@@ -379,7 +376,7 @@ impl VmpPMatOps for Module {
                 b.as_mut_ptr() as *mut vmp_pmat_t,
                 a.as_ptr(),
                 b.rows() as u64,
-                (b.limbs()*b.cols()) as u64,
+                (b.limbs() * b.cols()) as u64,
                 tmp_bytes.as_mut_ptr(),
             );
         }
@@ -387,7 +384,7 @@ impl VmpPMatOps for Module {
 
     fn vmp_prepare_row(&self, b: &mut VmpPMat, a: &[i64], row_i: usize, tmp_bytes: &mut [u8]) {
        #[cfg(debug_assertions)]
-        {
+        {
            assert_eq!(a.len(), b.limbs() * self.n() * b.cols());
            assert!(tmp_bytes.len() >= self.vmp_prepare_tmp_bytes(b.rows(), b.cols(), b.limbs()));
            assert_alignement(tmp_bytes.as_ptr());
@@ -399,7 +396,7 @@ impl VmpPMatOps for Module {
                 a.as_ptr(),
                 row_i as u64,
                 b.rows() as u64,
-                (b.limbs()*b.cols()) as u64,
+                (b.limbs() * b.cols()) as u64,
                 tmp_bytes.as_mut_ptr(),
             );
         }
@@ -419,7 +416,7 @@ impl VmpPMatOps for Module {
                 a.as_ptr() as *const vmp_pmat_t,
                 row_i as u64,
                 a.rows() as u64,
-                (a.limbs()*a.cols()) as u64,
+                (a.limbs() * a.cols()) as u64,
             );
         }
     }
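
A minimal stand-alone sketch of the sizing rule adopted by `alloc_aligned` in the lib.rs hunk above: the requested length is rounded up to the next power of two, but never below `DEFAULTALIGN / std::mem::size_of::<T>()` elements. The `DEFAULTALIGN = 64` value and the `padded_len` helper are illustrative assumptions, not the crate's actual constant or API.

// Illustrative sketch only: DEFAULTALIGN is assumed to be 64 bytes here,
// and `padded_len` mimics the expression used inside `alloc_aligned`.
const DEFAULTALIGN: usize = 64;

fn padded_len<T>(size: usize) -> usize {
    // Next power of two, floored at one full alignment block worth of elements.
    std::cmp::max(size.next_power_of_two(), DEFAULTALIGN / std::mem::size_of::<T>())
}

fn main() {
    assert_eq!(padded_len::<i64>(5), 8);    // 5 -> 8 (next power of two)
    assert_eq!(padded_len::<i64>(3), 8);    // floor: 64 / 8 = 8 elements
    assert_eq!(padded_len::<u8>(100), 128); // 100 -> 128
    println!("ok");
}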
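Similarly, a minimal sketch of the limb-major layout behind the new `VecZnxDft::at_ptr`/`at_limb`/`at_poly` accessors in the vec_znx_dft.rs hunk above: the (i, j)-th small poly (column i, limb j) starts at offset `n * (j * cols + i)`, so a limb is `cols` consecutive degree-`n` polys. The dimensions and the `at` closure below are illustrative; only the offset formula comes from the diff.

fn main() {
    // Illustrative dimensions (not tied to the crate's Module/VecZnxDft types).
    let (n, cols, limbs) = (4usize, 2usize, 3usize);

    // Backing buffer: limbs * cols small polys of degree n, stored contiguously.
    let data: Vec<f64> = (0..n * cols * limbs).map(|x| x as f64).collect();

    // Offset of the (i, j)-th poly, matching `self.n * (j * self.cols() + i)` in the diff.
    let at = |i: usize, j: usize| n * (j * cols + i);

    // Limb 1 = the `cols` consecutive polys of that limb, a slice of length n * cols,
    // which is what `at_limb(1)` exposes.
    let limb_1: &[f64] = &data[at(0, 1)..at(0, 1) + n * cols];

    // A single (i, j) poly is a slice of length n, as returned by `at_poly(i, j)`.
    let poly_1_2: &[f64] = &data[at(1, 2)..at(1, 2) + n];

    println!("limb 1:     {:?}", limb_1);
    println!("poly (1,2): {:?}", poly_1_2);
}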