Ensures allocated memory is initialized

This commit is contained in:
Jean-Philippe Bossuat
2025-02-25 13:23:18 +01:00
parent e4f4194945
commit 871b85e471
7 changed files with 135 additions and 70 deletions

View File

@@ -412,6 +412,8 @@ impl VmpPMatOps for Module {
}
fn vmp_prepare_contiguous(&self, b: &mut VmpPMat, a: &[i64], buf: &mut [u8]) {
debug_assert_eq!(a.len(), b.n * b.rows * b.cols);
debug_assert!(buf.len() >= self.vmp_prepare_tmp_bytes(b.rows(), b.cols()));
unsafe {
vmp::vmp_prepare_contiguous(
self.0,
@@ -426,6 +428,14 @@ impl VmpPMatOps for Module {
fn vmp_prepare_dblptr(&self, b: &mut VmpPMat, a: &[&[i64]], buf: &mut [u8]) {
let ptrs: Vec<*const i64> = a.iter().map(|v| v.as_ptr()).collect();
#[cfg(debug_assertions)]
{
debug_assert_eq!(a.len(), b.rows);
a.iter().for_each(|ai| {
debug_assert_eq!(ai.len(), b.n * b.cols);
});
debug_assert!(buf.len() >= self.vmp_prepare_tmp_bytes(b.rows(), b.cols()));
}
unsafe {
vmp::vmp_prepare_dblptr(
self.0,
@@ -439,7 +449,8 @@ impl VmpPMatOps for Module {
}
fn vmp_prepare_row(&self, b: &mut VmpPMat, a: &[i64], row_i: usize, buf: &mut [u8]) {
debug_assert!(a.len() == b.cols() * self.n());
debug_assert_eq!(a.len(), b.cols() * self.n());
debug_assert!(buf.len() >= self.vmp_prepare_tmp_bytes(b.rows(), b.cols()));
unsafe {
vmp::vmp_prepare_row(
self.0,
@@ -478,6 +489,9 @@ impl VmpPMatOps for Module {
b: &VmpPMat,
buf: &mut [u8],
) {
debug_assert!(
buf.len() >= self.vmp_apply_dft_tmp_bytes(c.cols(), a.cols(), b.rows(), b.cols())
);
unsafe {
vmp::vmp_apply_dft(
self.0,
@@ -513,6 +527,10 @@ impl VmpPMatOps for Module {
}
fn vmp_apply_dft_to_dft(&self, c: &mut VecZnxDft, a: &VecZnxDft, b: &VmpPMat, buf: &mut [u8]) {
debug_assert!(
buf.len()
>= self.vmp_apply_dft_to_dft_tmp_bytes(c.cols(), a.cols(), b.rows(), b.cols())
);
unsafe {
vmp::vmp_apply_dft_to_dft(
self.0,
@@ -529,6 +547,10 @@ impl VmpPMatOps for Module {
}
fn vmp_apply_dft_to_dft_inplace(&self, b: &mut VecZnxDft, a: &VmpPMat, buf: &mut [u8]) {
debug_assert!(
buf.len()
>= self.vmp_apply_dft_to_dft_tmp_bytes(b.cols(), b.cols(), a.rows(), a.cols())
);
unsafe {
vmp::vmp_apply_dft_to_dft(
self.0,