Ensures allocated memory is initialized

This commit is contained in:
Jean-Philippe Bossuat
2025-02-25 13:23:18 +01:00
parent e4f4194945
commit 871b85e471
7 changed files with 135 additions and 70 deletions

View File

@@ -127,7 +127,7 @@ fn encode_vec_i64<T: VecZnxCommon>(
) { ) {
let cols: usize = (log_k + log_base2k - 1) / log_base2k; let cols: usize = (log_k + log_base2k - 1) / log_base2k;
assert!( debug_assert!(
cols <= a.cols(), cols <= a.cols(),
"invalid argument log_k: (log_k + a.log_base2k - 1)/a.log_base2k={} > a.cols()={}", "invalid argument log_k: (log_k + a.log_base2k - 1)/a.log_base2k={} > a.cols()={}",
cols, cols,
@@ -177,7 +177,7 @@ fn encode_vec_i64<T: VecZnxCommon>(
fn decode_vec_i64<T: VecZnxCommon>(a: &T, log_base2k: usize, log_k: usize, data: &mut [i64]) { fn decode_vec_i64<T: VecZnxCommon>(a: &T, log_base2k: usize, log_k: usize, data: &mut [i64]) {
let cols: usize = (log_k + log_base2k - 1) / log_base2k; let cols: usize = (log_k + log_base2k - 1) / log_base2k;
assert!( debug_assert!(
data.len() >= a.n(), data.len() >= a.n(),
"invalid data: data.len()={} < a.n()={}", "invalid data: data.len()={} < a.n()={}",
data.len(), data.len(),
@@ -201,7 +201,7 @@ fn decode_vec_i64<T: VecZnxCommon>(a: &T, log_base2k: usize, log_k: usize, data:
fn decode_vec_float<T: VecZnxCommon>(a: &T, log_base2k: usize, data: &mut [Float]) { fn decode_vec_float<T: VecZnxCommon>(a: &T, log_base2k: usize, data: &mut [Float]) {
let cols: usize = a.cols(); let cols: usize = a.cols();
assert!( debug_assert!(
data.len() >= a.n(), data.len() >= a.n(),
"invalid data: data.len()={} < a.n()={}", "invalid data: data.len()={} < a.n()={}",
data.len(), data.len(),
@@ -237,9 +237,9 @@ fn encode_coeff_i64<T: VecZnxCommon>(
value: i64, value: i64,
log_max: usize, log_max: usize,
) { ) {
assert!(i < a.n()); debug_assert!(i < a.n());
let cols: usize = (log_k + log_base2k - 1) / log_base2k; let cols: usize = (log_k + log_base2k - 1) / log_base2k;
assert!( debug_assert!(
cols <= a.cols(), cols <= a.cols(),
"invalid argument log_k: (log_k + a.log_base2k - 1)/a.log_base2k={} > a.cols()={}", "invalid argument log_k: (log_k + a.log_base2k - 1)/a.log_base2k={} > a.cols()={}",
cols, cols,
@@ -281,7 +281,7 @@ fn encode_coeff_i64<T: VecZnxCommon>(
fn decode_coeff_i64<T: VecZnxCommon>(a: &T, log_base2k: usize, log_k: usize, i: usize) -> i64 { fn decode_coeff_i64<T: VecZnxCommon>(a: &T, log_base2k: usize, log_k: usize, i: usize) -> i64 {
let cols: usize = (log_k + log_base2k - 1) / log_base2k; let cols: usize = (log_k + log_base2k - 1) / log_base2k;
assert!(i < a.n()); debug_assert!(i < a.n());
let data: &[i64] = a.raw(); let data: &[i64] = a.raw();
let mut res: i64 = data[i]; let mut res: i64 = data[i];
let rem: usize = log_base2k - (log_k % log_base2k); let rem: usize = log_base2k - (log_k % log_base2k);

View File

@@ -33,12 +33,16 @@ pub use vec_znx_dft::*;
pub use vmp::*; pub use vmp::*;
pub const GALOISGENERATOR: u64 = 5; pub const GALOISGENERATOR: u64 = 5;
pub const DEFAULTALIGN: usize = 64;
#[allow(dead_code)] fn is_aligned_custom<T>(ptr: *const T, align: usize) -> bool {
fn is_aligned<T>(ptr: *const T, align: usize) -> bool {
(ptr as usize) % align == 0 (ptr as usize) % align == 0
} }
fn is_aligned<T>(ptr: *const T) -> bool {
is_aligned_custom(ptr, DEFAULTALIGN)
}
pub fn cast<T, V>(data: &[T]) -> &[V] { pub fn cast<T, V>(data: &[T]) -> &[V] {
let ptr: *const V = data.as_ptr() as *const V; let ptr: *const V = data.as_ptr() as *const V;
let len: usize = data.len() / std::mem::size_of::<V>(); let len: usize = data.len() / std::mem::size_of::<V>();
@@ -52,12 +56,15 @@ pub fn cast_mut<T, V>(data: &[T]) -> &mut [V] {
} }
use std::alloc::{alloc, Layout}; use std::alloc::{alloc, Layout};
use std::ptr;
pub fn alloc_aligned_u8(size: usize, align: usize) -> Vec<u8> { /// Allocates a block of bytes with a custom alignement.
assert_eq!( /// Alignement must be a power of two and size a multiple of the alignement.
align & (align - 1), /// Allocated memory is initialized to zero.
0, pub fn alloc_aligned_custom_u8(size: usize, align: usize) -> Vec<u8> {
"align={} must be a power of two", assert!(
align.is_power_of_two(),
"Alignment must be a power of two but is {}",
align align
); );
assert_eq!( assert_eq!(
@@ -73,11 +80,28 @@ pub fn alloc_aligned_u8(size: usize, align: usize) -> Vec<u8> {
if ptr.is_null() { if ptr.is_null() {
panic!("Memory allocation failed"); panic!("Memory allocation failed");
} }
assert!(
is_aligned_custom(ptr, align),
"Memory allocation at {:p} is not aligned to {} bytes",
ptr,
align
);
// Init allocated memory to zero
ptr::write_bytes(ptr, 0, size);
Vec::from_raw_parts(ptr, size, size) Vec::from_raw_parts(ptr, size, size)
} }
} }
pub fn alloc_aligned<T>(size: usize, align: usize) -> Vec<T> { /// Allocates a block of bytes aligned with [DEFAULTALIGN].
/// Size must be amultiple of [DEFAULTALIGN].
/// /// Allocated memory is initialized to zero.
pub fn alloc_aligned_u8(size: usize) -> Vec<u8> {
alloc_aligned_custom_u8(size, DEFAULTALIGN)
}
/// Allocates a block of T aligned with [DEFAULTALIGN].
/// Size of T * size msut be a multiple of [DEFAULTALIGN].
pub fn alloc_aligned_custom<T>(size: usize, align: usize) -> Vec<T> {
assert_eq!( assert_eq!(
(size * std::mem::size_of::<T>()) % align, (size * std::mem::size_of::<T>()) % align,
0, 0,
@@ -85,7 +109,7 @@ pub fn alloc_aligned<T>(size: usize, align: usize) -> Vec<T> {
size, size,
align align
); );
let mut vec_u8: Vec<u8> = alloc_aligned_u8(std::mem::size_of::<T>() * size, align); let mut vec_u8: Vec<u8> = alloc_aligned_custom_u8(std::mem::size_of::<T>() * size, align);
let ptr: *mut T = vec_u8.as_mut_ptr() as *mut T; let ptr: *mut T = vec_u8.as_mut_ptr() as *mut T;
let len: usize = vec_u8.len() / std::mem::size_of::<T>(); let len: usize = vec_u8.len() / std::mem::size_of::<T>();
let cap: usize = vec_u8.capacity() / std::mem::size_of::<T>(); let cap: usize = vec_u8.capacity() / std::mem::size_of::<T>();
@@ -93,6 +117,10 @@ pub fn alloc_aligned<T>(size: usize, align: usize) -> Vec<T> {
unsafe { Vec::from_raw_parts(ptr, len, cap) } unsafe { Vec::from_raw_parts(ptr, len, cap) }
} }
pub fn alloc_aligned<T>(size: usize) -> Vec<T> {
alloc_aligned_custom::<T>(size, DEFAULTALIGN)
}
fn alias_mut_slice_to_vec<T>(slice: &[T]) -> Vec<T> { fn alias_mut_slice_to_vec<T>(slice: &[T]) -> Vec<T> {
unsafe { unsafe {
let ptr: *mut T = slice.as_ptr() as *mut T; let ptr: *mut T = slice.as_ptr() as *mut T;

View File

@@ -1,5 +1,5 @@
use crate::ffi::svp; use crate::ffi::svp::{self, bytes_of_svp_ppol};
use crate::{alias_mut_slice_to_vec, Module, VecZnxApi, VecZnxDft}; use crate::{alias_mut_slice_to_vec, is_aligned, Module, VecZnxApi, VecZnxDft};
use crate::{alloc_aligned, cast, Infos}; use crate::{alloc_aligned, cast, Infos};
use rand::seq::SliceRandom; use rand::seq::SliceRandom;
@@ -17,7 +17,7 @@ impl Module {
impl Scalar { impl Scalar {
pub fn new(n: usize) -> Self { pub fn new(n: usize) -> Self {
Self(alloc_aligned::<i64>(n, 64)) Self(alloc_aligned::<i64>(n))
} }
pub fn n(&self) -> usize { pub fn n(&self) -> usize {
@@ -30,13 +30,14 @@ impl Scalar {
pub fn from_buffer(&mut self, n: usize, buf: &mut [u8]) { pub fn from_buffer(&mut self, n: usize, buf: &mut [u8]) {
let size: usize = Self::buffer_size(n); let size: usize = Self::buffer_size(n);
assert!( debug_assert!(
buf.len() >= size, buf.len() >= size,
"invalid buffer: buf.len()={} < self.buffer_size(n={})={}", "invalid buffer: buf.len()={} < self.buffer_size(n={})={}",
buf.len(), buf.len(),
n, n,
size size
); );
debug_assert!(is_aligned(buf.as_ptr()));
self.0 = alias_mut_slice_to_vec(cast::<u8, i64>(&buf[..size])) self.0 = alias_mut_slice_to_vec(cast::<u8, i64>(&buf[..size]))
} }
@@ -74,6 +75,8 @@ impl SvpPPol {
} }
pub fn from_bytes(size: usize, bytes: &mut [u8]) -> SvpPPol { pub fn from_bytes(size: usize, bytes: &mut [u8]) -> SvpPPol {
debug_assert!(is_aligned(bytes.as_ptr()));
debug_assert!(bytes.len() << 3 >= size);
SvpPPol(bytes.as_mut_ptr() as *mut svp::svp_ppol_t, size) SvpPPol(bytes.as_mut_ptr() as *mut svp::svp_ppol_t, size)
} }
@@ -125,7 +128,7 @@ impl SvpPPolOps for Module {
b: &T, b: &T,
b_cols: usize, b_cols: usize,
) { ) {
assert!( debug_assert!(
c.cols() >= b_cols, c.cols() >= b_cols,
"invalid c_vector: c_vector.cols()={} < b.cols()={}", "invalid c_vector: c_vector.cols()={} < b.cols()={}",
c.cols(), c.cols(),

View File

@@ -2,6 +2,7 @@ use crate::cast_mut;
use crate::ffi::vec_znx; use crate::ffi::vec_znx;
use crate::ffi::znx; use crate::ffi::znx;
use crate::ffi::znx::znx_zero_i64_ref; use crate::ffi::znx::znx_zero_i64_ref;
use crate::is_aligned;
use crate::{alias_mut_slice_to_vec, alloc_aligned}; use crate::{alias_mut_slice_to_vec, alloc_aligned};
use crate::{Infos, Module}; use crate::{Infos, Module};
use itertools::izip; use itertools::izip;
@@ -128,7 +129,7 @@ impl VecZnxApi for VecZnxBorrow {
/// the size of data is at least equal to [VecZnx::bytes_of]. /// the size of data is at least equal to [VecZnx::bytes_of].
fn from_bytes(n: usize, cols: usize, bytes: &mut [u8]) -> Self::Owned { fn from_bytes(n: usize, cols: usize, bytes: &mut [u8]) -> Self::Owned {
let size = Self::bytes_of(n, cols); let size = Self::bytes_of(n, cols);
assert!( debug_assert!(
bytes.len() >= size, bytes.len() >= size,
"invalid buffer: buf.len()={} < self.buffer_size(n={}, cols={})={}", "invalid buffer: buf.len()={} < self.buffer_size(n={}, cols={})={}",
bytes.len(), bytes.len(),
@@ -136,6 +137,7 @@ impl VecZnxApi for VecZnxBorrow {
cols, cols,
size size
); );
debug_assert!(is_aligned(bytes.as_ptr()));
VecZnxBorrow { VecZnxBorrow {
n: n, n: n,
cols: cols, cols: cols,
@@ -225,20 +227,20 @@ impl VecZnxApi for VecZnx {
/// ///
/// User must ensure that data is properly alligned and that /// User must ensure that data is properly alligned and that
/// the size of data is at least equal to [VecZnx::bytes_of]. /// the size of data is at least equal to [VecZnx::bytes_of].
fn from_bytes(n: usize, cols: usize, buf: &mut [u8]) -> Self::Owned { fn from_bytes(n: usize, cols: usize, bytes: &mut [u8]) -> Self::Owned {
let size = Self::bytes_of(n, cols); let size = Self::bytes_of(n, cols);
assert!( debug_assert!(
buf.len() >= size, bytes.len() >= size,
"invalid buffer: buf.len()={} < self.buffer_size(n={}, cols={})={}", "invalid bytes: bytes.len()={} < self.bytes_of(n={}, cols={})={}",
buf.len(), bytes.len(),
n, n,
cols, cols,
size size
); );
debug_assert!(is_aligned(bytes.as_ptr()));
VecZnx { VecZnx {
n: n, n: n,
data: alias_mut_slice_to_vec(cast_mut(&mut buf[..size])), data: alias_mut_slice_to_vec(cast_mut(&mut bytes[..size])),
} }
} }
@@ -348,7 +350,7 @@ impl VecZnx {
pub fn new(n: usize, cols: usize) -> Self { pub fn new(n: usize, cols: usize) -> Self {
Self { Self {
n: n, n: n,
data: alloc_aligned::<i64>(n * cols, 64), data: alloc_aligned::<i64>(n * cols),
} }
} }
@@ -399,17 +401,18 @@ pub fn switch_degree<A: VecZnxCommon, B: VecZnxCommon>(b: &mut B, a: &A) {
}); });
} }
fn normalize<T: VecZnxCommon>(log_base2k: usize, a: &mut T, carry: &mut [u8]) { fn normalize<T: VecZnxCommon>(log_base2k: usize, a: &mut T, tmp_bytes: &mut [u8]) {
let n: usize = a.n(); let n: usize = a.n();
assert!( debug_assert!(
carry.len() >= n * 8, tmp_bytes.len() >= n * 8,
"invalid carry: carry.len()={} < self.n()={}", "invalid tmp_bytes: tmp_bytes.len()={} < self.n()={}",
carry.len(), tmp_bytes.len(),
n n
); );
debug_assert!(is_aligned(tmp_bytes.as_ptr()));
let carry_i64: &mut [i64] = cast_mut(carry); let carry_i64: &mut [i64] = cast_mut(tmp_bytes);
unsafe { unsafe {
znx::znx_zero_i64_ref(n as u64, carry_i64.as_mut_ptr()); znx::znx_zero_i64_ref(n as u64, carry_i64.as_mut_ptr());
@@ -426,16 +429,18 @@ fn normalize<T: VecZnxCommon>(log_base2k: usize, a: &mut T, carry: &mut [u8]) {
} }
} }
pub fn rsh<T: VecZnxCommon>(log_base2k: usize, a: &mut T, k: usize, carry: &mut [u8]) { pub fn rsh<T: VecZnxCommon>(log_base2k: usize, a: &mut T, k: usize, tmp_bytes: &mut [u8]) {
let n: usize = a.n(); let n: usize = a.n();
assert!( debug_assert!(
carry.len() >> 3 >= n, tmp_bytes.len() >> 3 >= n,
"invalid carry: carry.len()/8={} < self.n()={}", "invalid carry: carry.len()/8={} < self.n()={}",
carry.len() >> 3, tmp_bytes.len() >> 3,
n n
); );
debug_assert!(is_aligned(tmp_bytes.as_ptr()));
let cols: usize = a.cols(); let cols: usize = a.cols();
let cols_steps: usize = k / log_base2k; let cols_steps: usize = k / log_base2k;
@@ -447,7 +452,7 @@ pub fn rsh<T: VecZnxCommon>(log_base2k: usize, a: &mut T, k: usize, carry: &mut
let k_rem = k % log_base2k; let k_rem = k % log_base2k;
if k_rem != 0 { if k_rem != 0 {
let carry_i64: &mut [i64] = cast_mut(carry); let carry_i64: &mut [i64] = cast_mut(tmp_bytes);
unsafe { unsafe {
znx::znx_zero_i64_ref(n as u64, carry_i64.as_mut_ptr()); znx::znx_zero_i64_ref(n as u64, carry_i64.as_mut_ptr());
@@ -469,7 +474,6 @@ pub fn rsh<T: VecZnxCommon>(log_base2k: usize, a: &mut T, k: usize, carry: &mut
pub trait VecZnxCommon: VecZnxApi + Infos {} pub trait VecZnxCommon: VecZnxApi + Infos {}
pub trait VecZnxOps { pub trait VecZnxOps {
/// Allocates a new [VecZnx]. /// Allocates a new [VecZnx].
/// ///
/// # Arguments /// # Arguments
@@ -560,10 +564,8 @@ impl VecZnxOps for Module {
self.n() * cols * 8 self.n() * cols * 8
} }
fn vec_znx_normalize_tmp_bytes(&self) -> usize{ fn vec_znx_normalize_tmp_bytes(&self) -> usize {
unsafe{ unsafe { vec_znx::vec_znx_normalize_base2k_tmp_bytes(self.0) as usize }
vec_znx::vec_znx_normalize_base2k_tmp_bytes(self.0) as usize
}
} }
// c <- a + b // c <- a + b
@@ -750,9 +752,9 @@ impl VecZnxOps for Module {
a: &A, a: &A,
a_cols: usize, a_cols: usize,
) { ) {
assert_eq!(a.n(), self.n()); debug_assert_eq!(a.n(), self.n());
assert_eq!(b.n(), self.n()); debug_assert_eq!(b.n(), self.n());
assert!(a.cols() >= a_cols); debug_assert!(a.cols() >= a_cols);
unsafe { unsafe {
vec_znx::vec_znx_automorphism( vec_znx::vec_znx_automorphism(
self.0, self.0,
@@ -803,8 +805,8 @@ impl VecZnxOps for Module {
/// izip!(a.data.iter(), b.data.iter()).for_each(|(a, b)| assert_eq!(a, b, "{} != {}", a, b)); /// izip!(a.data.iter(), b.data.iter()).for_each(|(a, b)| assert_eq!(a, b, "{} != {}", a, b));
/// ``` /// ```
fn vec_znx_automorphism_inplace<A: VecZnxCommon>(&self, k: i64, a: &mut A, a_cols: usize) { fn vec_znx_automorphism_inplace<A: VecZnxCommon>(&self, k: i64, a: &mut A, a_cols: usize) {
assert_eq!(a.n(), self.n()); debug_assert_eq!(a.n(), self.n());
assert!(a.cols() >= a_cols); debug_assert!(a.cols() >= a_cols);
unsafe { unsafe {
vec_znx::vec_znx_automorphism( vec_znx::vec_znx_automorphism(
self.0, self.0,
@@ -827,12 +829,12 @@ impl VecZnxOps for Module {
) { ) {
let (n_in, n_out) = (a.n(), b[0].n()); let (n_in, n_out) = (a.n(), b[0].n());
assert!( debug_assert!(
n_out < n_in, n_out < n_in,
"invalid a: output ring degree should be smaller" "invalid a: output ring degree should be smaller"
); );
b[1..].iter().for_each(|bi| { b[1..].iter().for_each(|bi| {
assert_eq!( debug_assert_eq!(
bi.n(), bi.n(),
n_out, n_out,
"invalid input a: all VecZnx must have the same degree" "invalid input a: all VecZnx must have the same degree"
@@ -853,12 +855,12 @@ impl VecZnxOps for Module {
fn vec_znx_merge<A: VecZnxCommon, B: VecZnxCommon>(&self, b: &mut B, a: &Vec<A>) { fn vec_znx_merge<A: VecZnxCommon, B: VecZnxCommon>(&self, b: &mut B, a: &Vec<A>) {
let (n_in, n_out) = (b.n(), a[0].n()); let (n_in, n_out) = (b.n(), a[0].n());
assert!( debug_assert!(
n_out < n_in, n_out < n_in,
"invalid a: output ring degree should be smaller" "invalid a: output ring degree should be smaller"
); );
a[1..].iter().for_each(|ai| { a[1..].iter().for_each(|ai| {
assert_eq!( debug_assert_eq!(
ai.n(), ai.n(),
n_out, n_out,
"invalid input a: all VecZnx must have the same degree" "invalid input a: all VecZnx must have the same degree"

View File

@@ -1,6 +1,6 @@
use crate::ffi::vec_znx_big; use crate::ffi::vec_znx_big;
use crate::ffi::vec_znx_dft; use crate::ffi::vec_znx_dft;
use crate::{Infos, Module, VecZnxApi, VecZnxDft}; use crate::{is_aligned, Infos, Module, VecZnxApi, VecZnxDft};
pub struct VecZnxBig(pub *mut vec_znx_big::vec_znx_bigcoeff_t, pub usize); pub struct VecZnxBig(pub *mut vec_znx_big::vec_znx_bigcoeff_t, pub usize);
@@ -9,6 +9,7 @@ impl VecZnxBig {
/// User must ensure that data is properly alligned and that /// User must ensure that data is properly alligned and that
/// the size of data is at least equal to [Module::bytes_of_vec_znx_big]. /// the size of data is at least equal to [Module::bytes_of_vec_znx_big].
pub fn from_bytes(cols: usize, data: &mut [u8]) -> VecZnxBig { pub fn from_bytes(cols: usize, data: &mut [u8]) -> VecZnxBig {
debug_assert!(is_aligned(data.as_ptr()));
VecZnxBig( VecZnxBig(
data.as_mut_ptr() as *mut vec_znx_big::vec_znx_bigcoeff_t, data.as_mut_ptr() as *mut vec_znx_big::vec_znx_bigcoeff_t,
cols, cols,
@@ -94,12 +95,13 @@ impl VecZnxBigOps for Module {
} }
fn new_vec_znx_big_from_bytes(&self, cols: usize, bytes: &mut [u8]) -> VecZnxBig { fn new_vec_znx_big_from_bytes(&self, cols: usize, bytes: &mut [u8]) -> VecZnxBig {
assert!( debug_assert!(
bytes.len() >= <Module as VecZnxBigOps>::bytes_of_vec_znx_big(self, cols), bytes.len() >= <Module as VecZnxBigOps>::bytes_of_vec_znx_big(self, cols),
"invalid bytes: bytes.len()={} < bytes_of_vec_znx_dft={}", "invalid bytes: bytes.len()={} < bytes_of_vec_znx_dft={}",
bytes.len(), bytes.len(),
<Module as VecZnxBigOps>::bytes_of_vec_znx_big(self, cols) <Module as VecZnxBigOps>::bytes_of_vec_znx_big(self, cols)
); );
debug_assert!(is_aligned(bytes.as_ptr()));
VecZnxBig::from_bytes(cols, bytes) VecZnxBig::from_bytes(cols, bytes)
} }
@@ -189,6 +191,7 @@ impl VecZnxBigOps for Module {
tmp_bytes.len(), tmp_bytes.len(),
<Module as VecZnxBigOps>::vec_znx_big_normalize_tmp_bytes(self) <Module as VecZnxBigOps>::vec_znx_big_normalize_tmp_bytes(self)
); );
debug_assert!(is_aligned(tmp_bytes.as_ptr()));
unsafe { unsafe {
vec_znx_big::vec_znx_big_normalize_base2k( vec_znx_big::vec_znx_big_normalize_base2k(
self.0, self.0,
@@ -223,6 +226,7 @@ impl VecZnxBigOps for Module {
tmp_bytes.len(), tmp_bytes.len(),
<Module as VecZnxBigOps>::vec_znx_big_range_normalize_base2k_tmp_bytes(self) <Module as VecZnxBigOps>::vec_znx_big_range_normalize_base2k_tmp_bytes(self)
); );
debug_assert!(is_aligned(tmp_bytes.as_ptr()));
unsafe { unsafe {
vec_znx_big::vec_znx_big_range_normalize_base2k( vec_znx_big::vec_znx_big_range_normalize_base2k(
self.0, self.0,

View File

@@ -1,7 +1,7 @@
use crate::ffi::vec_znx_big; use crate::ffi::vec_znx_big;
use crate::ffi::vec_znx_dft; use crate::ffi::vec_znx_dft;
use crate::ffi::vec_znx_dft::bytes_of_vec_znx_dft; use crate::ffi::vec_znx_dft::bytes_of_vec_znx_dft;
use crate::{Infos, Module, VecZnxApi, VecZnxBig}; use crate::{is_aligned, Infos, Module, VecZnxApi, VecZnxBig};
pub struct VecZnxDft(pub *mut vec_znx_dft::vec_znx_dft_t, pub usize); pub struct VecZnxDft(pub *mut vec_znx_dft::vec_znx_dft_t, pub usize);
@@ -9,8 +9,12 @@ impl VecZnxDft {
/// Returns a new [VecZnxDft] with the provided data as backing array. /// Returns a new [VecZnxDft] with the provided data as backing array.
/// User must ensure that data is properly alligned and that /// User must ensure that data is properly alligned and that
/// the size of data is at least equal to [Module::bytes_of_vec_znx_dft]. /// the size of data is at least equal to [Module::bytes_of_vec_znx_dft].
pub fn from_bytes(cols: usize, data: &mut [u8]) -> VecZnxDft { pub fn from_bytes(cols: usize, tmp_bytes: &mut [u8]) -> VecZnxDft {
VecZnxDft(data.as_mut_ptr() as *mut vec_znx_dft::vec_znx_dft_t, cols) debug_assert!(is_aligned(tmp_bytes.as_ptr()));
VecZnxDft(
tmp_bytes.as_mut_ptr() as *mut vec_znx_dft::vec_znx_dft_t,
cols,
)
} }
/// Cast a [VecZnxDft] into a [VecZnxBig]. /// Cast a [VecZnxDft] into a [VecZnxBig].
@@ -73,14 +77,15 @@ impl VecZnxDftOps for Module {
unsafe { VecZnxDft(vec_znx_dft::new_vec_znx_dft(self.0, cols as u64), cols) } unsafe { VecZnxDft(vec_znx_dft::new_vec_znx_dft(self.0, cols as u64), cols) }
} }
fn new_vec_znx_dft_from_bytes(&self, cols: usize, bytes: &mut [u8]) -> VecZnxDft { fn new_vec_znx_dft_from_bytes(&self, cols: usize, tmp_bytes: &mut [u8]) -> VecZnxDft {
assert!( debug_assert!(
bytes.len() >= <Module as VecZnxDftOps>::bytes_of_vec_znx_dft(self, cols), tmp_bytes.len() >= <Module as VecZnxDftOps>::bytes_of_vec_znx_dft(self, cols),
"invalid bytes: bytes.len()={} < bytes_of_vec_znx_dft={}", "invalid bytes: bytes.len()={} < bytes_of_vec_znx_dft={}",
bytes.len(), tmp_bytes.len(),
<Module as VecZnxDftOps>::bytes_of_vec_znx_dft(self, cols) <Module as VecZnxDftOps>::bytes_of_vec_znx_dft(self, cols)
); );
VecZnxDft::from_bytes(cols, bytes) debug_assert!(is_aligned(tmp_bytes.as_ptr()));
VecZnxDft::from_bytes(cols, tmp_bytes)
} }
fn bytes_of_vec_znx_dft(&self, cols: usize) -> usize { fn bytes_of_vec_znx_dft(&self, cols: usize) -> usize {
@@ -88,7 +93,7 @@ impl VecZnxDftOps for Module {
} }
fn vec_znx_idft_tmp_a(&self, b: &mut VecZnxBig, a: &mut VecZnxDft, a_limbs: usize) { fn vec_znx_idft_tmp_a(&self, b: &mut VecZnxBig, a: &mut VecZnxDft, a_limbs: usize) {
assert!( debug_assert!(
b.cols() >= a_limbs, b.cols() >= a_limbs,
"invalid c_vector: b_vector.cols()={} < a_limbs={}", "invalid c_vector: b_vector.cols()={} < a_limbs={}",
b.cols(), b.cols(),
@@ -108,7 +113,7 @@ impl VecZnxDftOps for Module {
/// # Panics /// # Panics
/// If b.cols < a_cols /// If b.cols < a_cols
fn vec_znx_dft<T: VecZnxApi + Infos>(&self, b: &mut VecZnxDft, a: &T, a_cols: usize) { fn vec_znx_dft<T: VecZnxApi + Infos>(&self, b: &mut VecZnxDft, a: &T, a_cols: usize) {
assert!( debug_assert!(
b.cols() >= a_cols, b.cols() >= a_cols,
"invalid a_cols: b.cols()={} < a_cols={}", "invalid a_cols: b.cols()={} < a_cols={}",
b.cols(), b.cols(),
@@ -134,24 +139,25 @@ impl VecZnxDftOps for Module {
a_cols: usize, a_cols: usize,
tmp_bytes: &mut [u8], tmp_bytes: &mut [u8],
) { ) {
assert!( debug_assert!(
b.cols() >= a_cols, b.cols() >= a_cols,
"invalid c_vector: b.cols()={} < a_cols={}", "invalid c_vector: b.cols()={} < a_cols={}",
b.cols(), b.cols(),
a_cols a_cols
); );
assert!( debug_assert!(
a.cols() >= a_cols, a.cols() >= a_cols,
"invalid c_vector: a.cols()={} < a_cols={}", "invalid c_vector: a.cols()={} < a_cols={}",
a.cols(), a.cols(),
a_cols a_cols
); );
assert!( debug_assert!(
tmp_bytes.len() <= <Module as VecZnxDftOps>::vec_znx_idft_tmp_bytes(self), tmp_bytes.len() <= <Module as VecZnxDftOps>::vec_znx_idft_tmp_bytes(self),
"invalid tmp_bytes: tmp_bytes.len()={} < self.vec_znx_idft_tmp_bytes()={}", "invalid tmp_bytes: tmp_bytes.len()={} < self.vec_znx_idft_tmp_bytes()={}",
tmp_bytes.len(), tmp_bytes.len(),
<Module as VecZnxDftOps>::vec_znx_idft_tmp_bytes(self) <Module as VecZnxDftOps>::vec_znx_idft_tmp_bytes(self)
); );
debug_assert!(is_aligned(tmp_bytes.as_ptr()));
unsafe { unsafe {
vec_znx_dft::vec_znx_idft( vec_znx_dft::vec_znx_idft(
self.0, self.0,

View File

@@ -412,6 +412,8 @@ impl VmpPMatOps for Module {
} }
fn vmp_prepare_contiguous(&self, b: &mut VmpPMat, a: &[i64], buf: &mut [u8]) { fn vmp_prepare_contiguous(&self, b: &mut VmpPMat, a: &[i64], buf: &mut [u8]) {
debug_assert_eq!(a.len(), b.n * b.rows * b.cols);
debug_assert!(buf.len() >= self.vmp_prepare_tmp_bytes(b.rows(), b.cols()));
unsafe { unsafe {
vmp::vmp_prepare_contiguous( vmp::vmp_prepare_contiguous(
self.0, self.0,
@@ -426,6 +428,14 @@ impl VmpPMatOps for Module {
fn vmp_prepare_dblptr(&self, b: &mut VmpPMat, a: &[&[i64]], buf: &mut [u8]) { fn vmp_prepare_dblptr(&self, b: &mut VmpPMat, a: &[&[i64]], buf: &mut [u8]) {
let ptrs: Vec<*const i64> = a.iter().map(|v| v.as_ptr()).collect(); let ptrs: Vec<*const i64> = a.iter().map(|v| v.as_ptr()).collect();
#[cfg(debug_assertions)]
{
debug_assert_eq!(a.len(), b.rows);
a.iter().for_each(|ai| {
debug_assert_eq!(ai.len(), b.n * b.cols);
});
debug_assert!(buf.len() >= self.vmp_prepare_tmp_bytes(b.rows(), b.cols()));
}
unsafe { unsafe {
vmp::vmp_prepare_dblptr( vmp::vmp_prepare_dblptr(
self.0, self.0,
@@ -439,7 +449,8 @@ impl VmpPMatOps for Module {
} }
fn vmp_prepare_row(&self, b: &mut VmpPMat, a: &[i64], row_i: usize, buf: &mut [u8]) { fn vmp_prepare_row(&self, b: &mut VmpPMat, a: &[i64], row_i: usize, buf: &mut [u8]) {
debug_assert!(a.len() == b.cols() * self.n()); debug_assert_eq!(a.len(), b.cols() * self.n());
debug_assert!(buf.len() >= self.vmp_prepare_tmp_bytes(b.rows(), b.cols()));
unsafe { unsafe {
vmp::vmp_prepare_row( vmp::vmp_prepare_row(
self.0, self.0,
@@ -478,6 +489,9 @@ impl VmpPMatOps for Module {
b: &VmpPMat, b: &VmpPMat,
buf: &mut [u8], buf: &mut [u8],
) { ) {
debug_assert!(
buf.len() >= self.vmp_apply_dft_tmp_bytes(c.cols(), a.cols(), b.rows(), b.cols())
);
unsafe { unsafe {
vmp::vmp_apply_dft( vmp::vmp_apply_dft(
self.0, self.0,
@@ -513,6 +527,10 @@ impl VmpPMatOps for Module {
} }
fn vmp_apply_dft_to_dft(&self, c: &mut VecZnxDft, a: &VecZnxDft, b: &VmpPMat, buf: &mut [u8]) { fn vmp_apply_dft_to_dft(&self, c: &mut VecZnxDft, a: &VecZnxDft, b: &VmpPMat, buf: &mut [u8]) {
debug_assert!(
buf.len()
>= self.vmp_apply_dft_to_dft_tmp_bytes(c.cols(), a.cols(), b.rows(), b.cols())
);
unsafe { unsafe {
vmp::vmp_apply_dft_to_dft( vmp::vmp_apply_dft_to_dft(
self.0, self.0,
@@ -529,6 +547,10 @@ impl VmpPMatOps for Module {
} }
fn vmp_apply_dft_to_dft_inplace(&self, b: &mut VecZnxDft, a: &VmpPMat, buf: &mut [u8]) { fn vmp_apply_dft_to_dft_inplace(&self, b: &mut VecZnxDft, a: &VmpPMat, buf: &mut [u8]) {
debug_assert!(
buf.len()
>= self.vmp_apply_dft_to_dft_tmp_bytes(b.cols(), b.cols(), a.rows(), a.cols())
);
unsafe { unsafe {
vmp::vmp_apply_dft_to_dft( vmp::vmp_apply_dft_to_dft(
self.0, self.0,