Ensures allocated memory is initialized

This commit is contained in:
Jean-Philippe Bossuat
2025-02-25 13:23:18 +01:00
parent e4f4194945
commit 871b85e471
7 changed files with 135 additions and 70 deletions

View File

@@ -1,7 +1,7 @@
use crate::ffi::vec_znx_big;
use crate::ffi::vec_znx_dft;
use crate::ffi::vec_znx_dft::bytes_of_vec_znx_dft;
use crate::{Infos, Module, VecZnxApi, VecZnxBig};
use crate::{is_aligned, Infos, Module, VecZnxApi, VecZnxBig};
pub struct VecZnxDft(pub *mut vec_znx_dft::vec_znx_dft_t, pub usize);
@@ -9,8 +9,12 @@ impl VecZnxDft {
/// Returns a new [VecZnxDft] with the provided data as backing array.
/// User must ensure that data is properly alligned and that
/// the size of data is at least equal to [Module::bytes_of_vec_znx_dft].
pub fn from_bytes(cols: usize, data: &mut [u8]) -> VecZnxDft {
VecZnxDft(data.as_mut_ptr() as *mut vec_znx_dft::vec_znx_dft_t, cols)
pub fn from_bytes(cols: usize, tmp_bytes: &mut [u8]) -> VecZnxDft {
debug_assert!(is_aligned(tmp_bytes.as_ptr()));
VecZnxDft(
tmp_bytes.as_mut_ptr() as *mut vec_znx_dft::vec_znx_dft_t,
cols,
)
}
/// Cast a [VecZnxDft] into a [VecZnxBig].
@@ -73,14 +77,15 @@ impl VecZnxDftOps for Module {
unsafe { VecZnxDft(vec_znx_dft::new_vec_znx_dft(self.0, cols as u64), cols) }
}
fn new_vec_znx_dft_from_bytes(&self, cols: usize, bytes: &mut [u8]) -> VecZnxDft {
assert!(
bytes.len() >= <Module as VecZnxDftOps>::bytes_of_vec_znx_dft(self, cols),
fn new_vec_znx_dft_from_bytes(&self, cols: usize, tmp_bytes: &mut [u8]) -> VecZnxDft {
debug_assert!(
tmp_bytes.len() >= <Module as VecZnxDftOps>::bytes_of_vec_znx_dft(self, cols),
"invalid bytes: bytes.len()={} < bytes_of_vec_znx_dft={}",
bytes.len(),
tmp_bytes.len(),
<Module as VecZnxDftOps>::bytes_of_vec_znx_dft(self, cols)
);
VecZnxDft::from_bytes(cols, bytes)
debug_assert!(is_aligned(tmp_bytes.as_ptr()));
VecZnxDft::from_bytes(cols, tmp_bytes)
}
fn bytes_of_vec_znx_dft(&self, cols: usize) -> usize {
@@ -88,7 +93,7 @@ impl VecZnxDftOps for Module {
}
fn vec_znx_idft_tmp_a(&self, b: &mut VecZnxBig, a: &mut VecZnxDft, a_limbs: usize) {
assert!(
debug_assert!(
b.cols() >= a_limbs,
"invalid c_vector: b_vector.cols()={} < a_limbs={}",
b.cols(),
@@ -108,7 +113,7 @@ impl VecZnxDftOps for Module {
/// # Panics
/// If b.cols < a_cols
fn vec_znx_dft<T: VecZnxApi + Infos>(&self, b: &mut VecZnxDft, a: &T, a_cols: usize) {
assert!(
debug_assert!(
b.cols() >= a_cols,
"invalid a_cols: b.cols()={} < a_cols={}",
b.cols(),
@@ -134,24 +139,25 @@ impl VecZnxDftOps for Module {
a_cols: usize,
tmp_bytes: &mut [u8],
) {
assert!(
debug_assert!(
b.cols() >= a_cols,
"invalid c_vector: b.cols()={} < a_cols={}",
b.cols(),
a_cols
);
assert!(
debug_assert!(
a.cols() >= a_cols,
"invalid c_vector: a.cols()={} < a_cols={}",
a.cols(),
a_cols
);
assert!(
debug_assert!(
tmp_bytes.len() <= <Module as VecZnxDftOps>::vec_znx_idft_tmp_bytes(self),
"invalid tmp_bytes: tmp_bytes.len()={} < self.vec_znx_idft_tmp_bytes()={}",
tmp_bytes.len(),
<Module as VecZnxDftOps>::vec_znx_idft_tmp_bytes(self)
);
debug_assert!(is_aligned(tmp_bytes.as_ptr()));
unsafe {
vec_znx_dft::vec_znx_idft(
self.0,