diff --git a/base2k/src/encoding.rs b/base2k/src/encoding.rs index 51840ad..af32db2 100644 --- a/base2k/src/encoding.rs +++ b/base2k/src/encoding.rs @@ -127,7 +127,7 @@ fn encode_vec_i64( ) { let cols: usize = (log_k + log_base2k - 1) / log_base2k; - assert!( + debug_assert!( cols <= a.cols(), "invalid argument log_k: (log_k + a.log_base2k - 1)/a.log_base2k={} > a.cols()={}", cols, @@ -177,7 +177,7 @@ fn encode_vec_i64( fn decode_vec_i64(a: &T, log_base2k: usize, log_k: usize, data: &mut [i64]) { let cols: usize = (log_k + log_base2k - 1) / log_base2k; - assert!( + debug_assert!( data.len() >= a.n(), "invalid data: data.len()={} < a.n()={}", data.len(), @@ -201,7 +201,7 @@ fn decode_vec_i64(a: &T, log_base2k: usize, log_k: usize, data: fn decode_vec_float(a: &T, log_base2k: usize, data: &mut [Float]) { let cols: usize = a.cols(); - assert!( + debug_assert!( data.len() >= a.n(), "invalid data: data.len()={} < a.n()={}", data.len(), @@ -237,9 +237,9 @@ fn encode_coeff_i64( value: i64, log_max: usize, ) { - assert!(i < a.n()); + debug_assert!(i < a.n()); let cols: usize = (log_k + log_base2k - 1) / log_base2k; - assert!( + debug_assert!( cols <= a.cols(), "invalid argument log_k: (log_k + a.log_base2k - 1)/a.log_base2k={} > a.cols()={}", cols, @@ -281,7 +281,7 @@ fn encode_coeff_i64( fn decode_coeff_i64(a: &T, log_base2k: usize, log_k: usize, i: usize) -> i64 { let cols: usize = (log_k + log_base2k - 1) / log_base2k; - assert!(i < a.n()); + debug_assert!(i < a.n()); let data: &[i64] = a.raw(); let mut res: i64 = data[i]; let rem: usize = log_base2k - (log_k % log_base2k); diff --git a/base2k/src/lib.rs b/base2k/src/lib.rs index c69243e..e003644 100644 --- a/base2k/src/lib.rs +++ b/base2k/src/lib.rs @@ -33,12 +33,16 @@ pub use vec_znx_dft::*; pub use vmp::*; pub const GALOISGENERATOR: u64 = 5; +pub const DEFAULTALIGN: usize = 64; -#[allow(dead_code)] -fn is_aligned(ptr: *const T, align: usize) -> bool { +fn is_aligned_custom(ptr: *const T, align: usize) -> bool { (ptr as usize) % align == 0 } +fn is_aligned(ptr: *const T) -> bool { + is_aligned_custom(ptr, DEFAULTALIGN) +} + pub fn cast(data: &[T]) -> &[V] { let ptr: *const V = data.as_ptr() as *const V; let len: usize = data.len() / std::mem::size_of::(); @@ -52,12 +56,15 @@ pub fn cast_mut(data: &[T]) -> &mut [V] { } use std::alloc::{alloc, Layout}; +use std::ptr; -pub fn alloc_aligned_u8(size: usize, align: usize) -> Vec { - assert_eq!( - align & (align - 1), - 0, - "align={} must be a power of two", +/// Allocates a block of bytes with a custom alignment. +/// Alignment must be a power of two and size a multiple of the alignment. +/// Allocated memory is initialized to zero. +pub fn alloc_aligned_custom_u8(size: usize, align: usize) -> Vec { + assert!( + align.is_power_of_two(), + "Alignment must be a power of two but is {}", align ); assert_eq!( @@ -73,11 +80,28 @@ pub fn alloc_aligned_u8(size: usize, align: usize) -> Vec { if ptr.is_null() { panic!("Memory allocation failed"); } + assert!( + is_aligned_custom(ptr, align), + "Memory allocation at {:p} is not aligned to {} bytes", + ptr, + align + ); + // Init allocated memory to zero + ptr::write_bytes(ptr, 0, size); Vec::from_raw_parts(ptr, size, size) } } -pub fn alloc_aligned(size: usize, align: usize) -> Vec { +/// Allocates a block of bytes aligned with [DEFAULTALIGN]. +/// Size must be a multiple of [DEFAULTALIGN]. +/// Allocated memory is initialized to zero. 
+pub fn alloc_aligned_u8(size: usize) -> Vec { + alloc_aligned_custom_u8(size, DEFAULTALIGN) +} + +/// Allocates a block of T with a custom alignment. +/// Size of T * size must be a multiple of the alignment. +pub fn alloc_aligned_custom(size: usize, align: usize) -> Vec { assert_eq!( (size * std::mem::size_of::()) % align, 0, @@ -85,7 +109,7 @@ pub fn alloc_aligned(size: usize, align: usize) -> Vec { size, align ); - let mut vec_u8: Vec = alloc_aligned_u8(std::mem::size_of::() * size, align); + let mut vec_u8: Vec = alloc_aligned_custom_u8(std::mem::size_of::() * size, align); let ptr: *mut T = vec_u8.as_mut_ptr() as *mut T; let len: usize = vec_u8.len() / std::mem::size_of::(); let cap: usize = vec_u8.capacity() / std::mem::size_of::(); @@ -93,6 +117,10 @@ pub fn alloc_aligned(size: usize, align: usize) -> Vec { unsafe { Vec::from_raw_parts(ptr, len, cap) } } +pub fn alloc_aligned(size: usize) -> Vec { + alloc_aligned_custom::(size, DEFAULTALIGN) +} + fn alias_mut_slice_to_vec(slice: &[T]) -> Vec { unsafe { let ptr: *mut T = slice.as_ptr() as *mut T; diff --git a/base2k/src/svp.rs b/base2k/src/svp.rs index 9bfe2ef..8b329c0 100644 --- a/base2k/src/svp.rs +++ b/base2k/src/svp.rs @@ -1,5 +1,5 @@ -use crate::ffi::svp; -use crate::{alias_mut_slice_to_vec, Module, VecZnxApi, VecZnxDft}; +use crate::ffi::svp::{self, bytes_of_svp_ppol}; +use crate::{alias_mut_slice_to_vec, is_aligned, Module, VecZnxApi, VecZnxDft}; use crate::{alloc_aligned, cast, Infos}; use rand::seq::SliceRandom; @@ -17,7 +17,7 @@ impl Module { impl Scalar { pub fn new(n: usize) -> Self { - Self(alloc_aligned::(n, 64)) + Self(alloc_aligned::(n)) } pub fn n(&self) -> usize { @@ -30,13 +30,14 @@ impl Scalar { pub fn from_buffer(&mut self, n: usize, buf: &mut [u8]) { let size: usize = Self::buffer_size(n); - assert!( + debug_assert!( buf.len() >= size, "invalid buffer: buf.len()={} < self.buffer_size(n={})={}", buf.len(), n, size ); + debug_assert!(is_aligned(buf.as_ptr())); self.0 = alias_mut_slice_to_vec(cast::(&buf[..size])) } @@ -74,6 +75,8 @@ impl SvpPPol { } pub fn from_bytes(size: usize, bytes: &mut [u8]) -> SvpPPol { + debug_assert!(is_aligned(bytes.as_ptr())); + debug_assert!(bytes.len() << 3 >= size); SvpPPol(bytes.as_mut_ptr() as *mut svp::svp_ppol_t, size) } @@ -125,7 +128,7 @@ impl SvpPPolOps for Module { b: &T, b_cols: usize, ) { - assert!( + debug_assert!( c.cols() >= b_cols, "invalid c_vector: c_vector.cols()={} < b.cols()={}", c.cols(), diff --git a/base2k/src/vec_znx.rs b/base2k/src/vec_znx.rs index e615429..642d850 100644 --- a/base2k/src/vec_znx.rs +++ b/base2k/src/vec_znx.rs @@ -2,6 +2,7 @@ use crate::cast_mut; use crate::ffi::vec_znx; use crate::ffi::znx; use crate::ffi::znx::znx_zero_i64_ref; +use crate::is_aligned; use crate::{alias_mut_slice_to_vec, alloc_aligned}; use crate::{Infos, Module}; use itertools::izip; @@ -128,7 +129,7 @@ impl VecZnxApi for VecZnxBorrow { /// the size of data is at least equal to [VecZnx::bytes_of]. fn from_bytes(n: usize, cols: usize, bytes: &mut [u8]) -> Self::Owned { let size = Self::bytes_of(n, cols); - assert!( + debug_assert!( bytes.len() >= size, "invalid buffer: buf.len()={} < self.buffer_size(n={}, cols={})={}", bytes.len(), @@ -136,6 +137,7 @@ impl VecZnxApi for VecZnxBorrow { cols, size ); + debug_assert!(is_aligned(bytes.as_ptr())); VecZnxBorrow { n: n, cols: cols, @@ -225,20 +227,20 @@ impl VecZnxApi for VecZnx { /// /// User must ensure that data is properly alligned and that /// the size of data is at least equal to [VecZnx::bytes_of]. 
- fn from_bytes(n: usize, cols: usize, buf: &mut [u8]) -> Self::Owned { + fn from_bytes(n: usize, cols: usize, bytes: &mut [u8]) -> Self::Owned { let size = Self::bytes_of(n, cols); - assert!( - buf.len() >= size, - "invalid buffer: buf.len()={} < self.buffer_size(n={}, cols={})={}", - buf.len(), + debug_assert!( + bytes.len() >= size, + "invalid bytes: bytes.len()={} < self.bytes_of(n={}, cols={})={}", + bytes.len(), n, cols, size ); - + debug_assert!(is_aligned(bytes.as_ptr())); VecZnx { n: n, - data: alias_mut_slice_to_vec(cast_mut(&mut buf[..size])), + data: alias_mut_slice_to_vec(cast_mut(&mut bytes[..size])), } } @@ -348,7 +350,7 @@ impl VecZnx { pub fn new(n: usize, cols: usize) -> Self { Self { n: n, - data: alloc_aligned::(n * cols, 64), + data: alloc_aligned::(n * cols), } } @@ -399,17 +401,18 @@ pub fn switch_degree(b: &mut B, a: &A) { }); } -fn normalize(log_base2k: usize, a: &mut T, carry: &mut [u8]) { +fn normalize(log_base2k: usize, a: &mut T, tmp_bytes: &mut [u8]) { let n: usize = a.n(); - assert!( - carry.len() >= n * 8, - "invalid carry: carry.len()={} < self.n()={}", - carry.len(), + debug_assert!( + tmp_bytes.len() >= n * 8, + "invalid tmp_bytes: tmp_bytes.len()={} < self.n()={}", + tmp_bytes.len(), n ); + debug_assert!(is_aligned(tmp_bytes.as_ptr())); - let carry_i64: &mut [i64] = cast_mut(carry); + let carry_i64: &mut [i64] = cast_mut(tmp_bytes); unsafe { znx::znx_zero_i64_ref(n as u64, carry_i64.as_mut_ptr()); @@ -426,16 +429,18 @@ fn normalize(log_base2k: usize, a: &mut T, carry: &mut [u8]) { } } -pub fn rsh(log_base2k: usize, a: &mut T, k: usize, carry: &mut [u8]) { +pub fn rsh(log_base2k: usize, a: &mut T, k: usize, tmp_bytes: &mut [u8]) { let n: usize = a.n(); - assert!( - carry.len() >> 3 >= n, + debug_assert!( + tmp_bytes.len() >> 3 >= n, "invalid carry: carry.len()/8={} < self.n()={}", - carry.len() >> 3, + tmp_bytes.len() >> 3, n ); + debug_assert!(is_aligned(tmp_bytes.as_ptr())); + let cols: usize = a.cols(); let cols_steps: usize = k / log_base2k; @@ -447,7 +452,7 @@ pub fn rsh(log_base2k: usize, a: &mut T, k: usize, carry: &mut let k_rem = k % log_base2k; if k_rem != 0 { - let carry_i64: &mut [i64] = cast_mut(carry); + let carry_i64: &mut [i64] = cast_mut(tmp_bytes); unsafe { znx::znx_zero_i64_ref(n as u64, carry_i64.as_mut_ptr()); @@ -469,7 +474,6 @@ pub fn rsh(log_base2k: usize, a: &mut T, k: usize, carry: &mut pub trait VecZnxCommon: VecZnxApi + Infos {} pub trait VecZnxOps { - /// Allocates a new [VecZnx]. 
/// /// # Arguments @@ -560,10 +564,8 @@ impl VecZnxOps for Module { self.n() * cols * 8 } - fn vec_znx_normalize_tmp_bytes(&self) -> usize{ - unsafe{ - vec_znx::vec_znx_normalize_base2k_tmp_bytes(self.0) as usize - } + fn vec_znx_normalize_tmp_bytes(&self) -> usize { + unsafe { vec_znx::vec_znx_normalize_base2k_tmp_bytes(self.0) as usize } } // c <- a + b @@ -750,9 +752,9 @@ impl VecZnxOps for Module { a: &A, a_cols: usize, ) { - assert_eq!(a.n(), self.n()); - assert_eq!(b.n(), self.n()); - assert!(a.cols() >= a_cols); + debug_assert_eq!(a.n(), self.n()); + debug_assert_eq!(b.n(), self.n()); + debug_assert!(a.cols() >= a_cols); unsafe { vec_znx::vec_znx_automorphism( self.0, @@ -803,8 +805,8 @@ impl VecZnxOps for Module { /// izip!(a.data.iter(), b.data.iter()).for_each(|(a, b)| assert_eq!(a, b, "{} != {}", a, b)); /// ``` fn vec_znx_automorphism_inplace(&self, k: i64, a: &mut A, a_cols: usize) { - assert_eq!(a.n(), self.n()); - assert!(a.cols() >= a_cols); + debug_assert_eq!(a.n(), self.n()); + debug_assert!(a.cols() >= a_cols); unsafe { vec_znx::vec_znx_automorphism( self.0, @@ -827,12 +829,12 @@ impl VecZnxOps for Module { ) { let (n_in, n_out) = (a.n(), b[0].n()); - assert!( + debug_assert!( n_out < n_in, "invalid a: output ring degree should be smaller" ); b[1..].iter().for_each(|bi| { - assert_eq!( + debug_assert_eq!( bi.n(), n_out, "invalid input a: all VecZnx must have the same degree" @@ -853,12 +855,12 @@ impl VecZnxOps for Module { fn vec_znx_merge(&self, b: &mut B, a: &Vec) { let (n_in, n_out) = (b.n(), a[0].n()); - assert!( + debug_assert!( n_out < n_in, "invalid a: output ring degree should be smaller" ); a[1..].iter().for_each(|ai| { - assert_eq!( + debug_assert_eq!( ai.n(), n_out, "invalid input a: all VecZnx must have the same degree" diff --git a/base2k/src/vec_znx_big.rs b/base2k/src/vec_znx_big.rs index c837240..9b24d8c 100644 --- a/base2k/src/vec_znx_big.rs +++ b/base2k/src/vec_znx_big.rs @@ -1,6 +1,6 @@ use crate::ffi::vec_znx_big; use crate::ffi::vec_znx_dft; -use crate::{Infos, Module, VecZnxApi, VecZnxDft}; +use crate::{is_aligned, Infos, Module, VecZnxApi, VecZnxDft}; pub struct VecZnxBig(pub *mut vec_znx_big::vec_znx_bigcoeff_t, pub usize); @@ -9,6 +9,7 @@ impl VecZnxBig { /// User must ensure that data is properly alligned and that /// the size of data is at least equal to [Module::bytes_of_vec_znx_big]. 
pub fn from_bytes(cols: usize, data: &mut [u8]) -> VecZnxBig { + debug_assert!(is_aligned(data.as_ptr())); VecZnxBig( data.as_mut_ptr() as *mut vec_znx_big::vec_znx_bigcoeff_t, cols, @@ -94,12 +95,13 @@ impl VecZnxBigOps for Module { } fn new_vec_znx_big_from_bytes(&self, cols: usize, bytes: &mut [u8]) -> VecZnxBig { - assert!( + debug_assert!( bytes.len() >= ::bytes_of_vec_znx_big(self, cols), "invalid bytes: bytes.len()={} < bytes_of_vec_znx_dft={}", bytes.len(), ::bytes_of_vec_znx_big(self, cols) ); + debug_assert!(is_aligned(bytes.as_ptr())); VecZnxBig::from_bytes(cols, bytes) } @@ -189,6 +191,7 @@ impl VecZnxBigOps for Module { tmp_bytes.len(), ::vec_znx_big_normalize_tmp_bytes(self) ); + debug_assert!(is_aligned(tmp_bytes.as_ptr())); unsafe { vec_znx_big::vec_znx_big_normalize_base2k( self.0, @@ -223,6 +226,7 @@ impl VecZnxBigOps for Module { tmp_bytes.len(), ::vec_znx_big_range_normalize_base2k_tmp_bytes(self) ); + debug_assert!(is_aligned(tmp_bytes.as_ptr())); unsafe { vec_znx_big::vec_znx_big_range_normalize_base2k( self.0, diff --git a/base2k/src/vec_znx_dft.rs b/base2k/src/vec_znx_dft.rs index 9d7ad0e..8cc7a42 100644 --- a/base2k/src/vec_znx_dft.rs +++ b/base2k/src/vec_znx_dft.rs @@ -1,7 +1,7 @@ use crate::ffi::vec_znx_big; use crate::ffi::vec_znx_dft; use crate::ffi::vec_znx_dft::bytes_of_vec_znx_dft; -use crate::{Infos, Module, VecZnxApi, VecZnxBig}; +use crate::{is_aligned, Infos, Module, VecZnxApi, VecZnxBig}; pub struct VecZnxDft(pub *mut vec_znx_dft::vec_znx_dft_t, pub usize); @@ -9,8 +9,12 @@ impl VecZnxDft { /// Returns a new [VecZnxDft] with the provided data as backing array. /// User must ensure that data is properly alligned and that /// the size of data is at least equal to [Module::bytes_of_vec_znx_dft]. - pub fn from_bytes(cols: usize, data: &mut [u8]) -> VecZnxDft { - VecZnxDft(data.as_mut_ptr() as *mut vec_znx_dft::vec_znx_dft_t, cols) + pub fn from_bytes(cols: usize, tmp_bytes: &mut [u8]) -> VecZnxDft { + debug_assert!(is_aligned(tmp_bytes.as_ptr())); + VecZnxDft( + tmp_bytes.as_mut_ptr() as *mut vec_znx_dft::vec_znx_dft_t, + cols, + ) } /// Cast a [VecZnxDft] into a [VecZnxBig]. 
@@ -73,14 +77,15 @@ impl VecZnxDftOps for Module { unsafe { VecZnxDft(vec_znx_dft::new_vec_znx_dft(self.0, cols as u64), cols) } } - fn new_vec_znx_dft_from_bytes(&self, cols: usize, bytes: &mut [u8]) -> VecZnxDft { - assert!( - bytes.len() >= ::bytes_of_vec_znx_dft(self, cols), + fn new_vec_znx_dft_from_bytes(&self, cols: usize, tmp_bytes: &mut [u8]) -> VecZnxDft { + debug_assert!( + tmp_bytes.len() >= ::bytes_of_vec_znx_dft(self, cols), "invalid bytes: bytes.len()={} < bytes_of_vec_znx_dft={}", - bytes.len(), + tmp_bytes.len(), ::bytes_of_vec_znx_dft(self, cols) ); - VecZnxDft::from_bytes(cols, bytes) + debug_assert!(is_aligned(tmp_bytes.as_ptr())); + VecZnxDft::from_bytes(cols, tmp_bytes) } fn bytes_of_vec_znx_dft(&self, cols: usize) -> usize { @@ -88,7 +93,7 @@ impl VecZnxDftOps for Module { } fn vec_znx_idft_tmp_a(&self, b: &mut VecZnxBig, a: &mut VecZnxDft, a_limbs: usize) { - assert!( + debug_assert!( b.cols() >= a_limbs, "invalid c_vector: b_vector.cols()={} < a_limbs={}", b.cols(), @@ -108,7 +113,7 @@ impl VecZnxDftOps for Module { /// # Panics /// If b.cols < a_cols fn vec_znx_dft(&self, b: &mut VecZnxDft, a: &T, a_cols: usize) { - assert!( + debug_assert!( b.cols() >= a_cols, "invalid a_cols: b.cols()={} < a_cols={}", b.cols(), @@ -134,24 +139,25 @@ impl VecZnxDftOps for Module { a_cols: usize, tmp_bytes: &mut [u8], ) { - assert!( + debug_assert!( b.cols() >= a_cols, "invalid c_vector: b.cols()={} < a_cols={}", b.cols(), a_cols ); - assert!( + debug_assert!( a.cols() >= a_cols, "invalid c_vector: a.cols()={} < a_cols={}", a.cols(), a_cols ); - assert!( + debug_assert!( tmp_bytes.len() <= ::vec_znx_idft_tmp_bytes(self), "invalid tmp_bytes: tmp_bytes.len()={} < self.vec_znx_idft_tmp_bytes()={}", tmp_bytes.len(), ::vec_znx_idft_tmp_bytes(self) ); + debug_assert!(is_aligned(tmp_bytes.as_ptr())); unsafe { vec_znx_dft::vec_znx_idft( self.0, diff --git a/base2k/src/vmp.rs b/base2k/src/vmp.rs index d572e11..52459fe 100644 --- a/base2k/src/vmp.rs +++ b/base2k/src/vmp.rs @@ -412,6 +412,8 @@ impl VmpPMatOps for Module { } fn vmp_prepare_contiguous(&self, b: &mut VmpPMat, a: &[i64], buf: &mut [u8]) { + debug_assert_eq!(a.len(), b.n * b.rows * b.cols); + debug_assert!(buf.len() >= self.vmp_prepare_tmp_bytes(b.rows(), b.cols())); unsafe { vmp::vmp_prepare_contiguous( self.0, @@ -426,6 +428,14 @@ impl VmpPMatOps for Module { fn vmp_prepare_dblptr(&self, b: &mut VmpPMat, a: &[&[i64]], buf: &mut [u8]) { let ptrs: Vec<*const i64> = a.iter().map(|v| v.as_ptr()).collect(); + #[cfg(debug_assertions)] + { + debug_assert_eq!(a.len(), b.rows); + a.iter().for_each(|ai| { + debug_assert_eq!(ai.len(), b.n * b.cols); + }); + debug_assert!(buf.len() >= self.vmp_prepare_tmp_bytes(b.rows(), b.cols())); + } unsafe { vmp::vmp_prepare_dblptr( self.0, @@ -439,7 +449,8 @@ impl VmpPMatOps for Module { } fn vmp_prepare_row(&self, b: &mut VmpPMat, a: &[i64], row_i: usize, buf: &mut [u8]) { - debug_assert!(a.len() == b.cols() * self.n()); + debug_assert_eq!(a.len(), b.cols() * self.n()); + debug_assert!(buf.len() >= self.vmp_prepare_tmp_bytes(b.rows(), b.cols())); unsafe { vmp::vmp_prepare_row( self.0, @@ -478,6 +489,9 @@ impl VmpPMatOps for Module { b: &VmpPMat, buf: &mut [u8], ) { + debug_assert!( + buf.len() >= self.vmp_apply_dft_tmp_bytes(c.cols(), a.cols(), b.rows(), b.cols()) + ); unsafe { vmp::vmp_apply_dft( self.0, @@ -513,6 +527,10 @@ impl VmpPMatOps for Module { } fn vmp_apply_dft_to_dft(&self, c: &mut VecZnxDft, a: &VecZnxDft, b: &VmpPMat, buf: &mut [u8]) { + debug_assert!( + buf.len() + >= 
self.vmp_apply_dft_to_dft_tmp_bytes(c.cols(), a.cols(), b.rows(), b.cols()) + ); unsafe { vmp::vmp_apply_dft_to_dft( self.0, @@ -529,6 +547,10 @@ impl VmpPMatOps for Module { } fn vmp_apply_dft_to_dft_inplace(&self, b: &mut VecZnxDft, a: &VmpPMat, buf: &mut [u8]) { + debug_assert!( + buf.len() + >= self.vmp_apply_dft_to_dft_tmp_bytes(b.cols(), b.cols(), a.rows(), a.cols()) + ); unsafe { vmp::vmp_apply_dft_to_dft( self.0,
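
Usage sketch (not part of the diff): a minimal example of the new default-aligned, zero-initialized allocation helpers introduced in lib.rs above. It assumes the crate is consumed under the name `base2k` and that `alloc_aligned` keeps the generic type parameter used via turbofish elsewhere in the diff; the sizes are illustrative and must remain multiples of the alignment.

use base2k::{alloc_aligned, alloc_aligned_custom_u8, alloc_aligned_u8, DEFAULTALIGN};

fn main() {
    // 1 KiB zero-initialized scratch buffer aligned to DEFAULTALIGN (64 bytes),
    // e.g. to back a from_bytes constructor or a tmp_bytes argument.
    let tmp_bytes: Vec<u8> = alloc_aligned_u8(1024);
    assert_eq!(tmp_bytes.as_ptr() as usize % DEFAULTALIGN, 0);
    assert!(tmp_bytes.iter().all(|&b| b == 0));

    // Typed variant: 64 i64 coefficients (64 * 8 = 512 bytes, a multiple of DEFAULTALIGN).
    let coeffs: Vec<i64> = alloc_aligned::<i64>(64);
    assert_eq!(coeffs.len(), 64);

    // Custom alignment: must be a power of two, with size a multiple of the alignment.
    let page: Vec<u8> = alloc_aligned_custom_u8(4096, 4096);
    assert_eq!(page.as_ptr() as usize % 4096, 0);
}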