mirror of
https://github.com/arnaucube/poulpy.git
synced 2026-02-10 13:16:44 +01:00
Ensures allocated memory is initialized
This commit is contained in:
@@ -2,6 +2,7 @@ use crate::cast_mut;
|
||||
use crate::ffi::vec_znx;
|
||||
use crate::ffi::znx;
|
||||
use crate::ffi::znx::znx_zero_i64_ref;
|
||||
use crate::is_aligned;
|
||||
use crate::{alias_mut_slice_to_vec, alloc_aligned};
|
||||
use crate::{Infos, Module};
|
||||
use itertools::izip;
|
||||
@@ -128,7 +129,7 @@ impl VecZnxApi for VecZnxBorrow {
|
||||
/// the size of data is at least equal to [VecZnx::bytes_of].
|
||||
fn from_bytes(n: usize, cols: usize, bytes: &mut [u8]) -> Self::Owned {
|
||||
let size = Self::bytes_of(n, cols);
|
||||
assert!(
|
||||
debug_assert!(
|
||||
bytes.len() >= size,
|
||||
"invalid buffer: buf.len()={} < self.buffer_size(n={}, cols={})={}",
|
||||
bytes.len(),
|
||||
@@ -136,6 +137,7 @@ impl VecZnxApi for VecZnxBorrow {
|
||||
cols,
|
||||
size
|
||||
);
|
||||
debug_assert!(is_aligned(bytes.as_ptr()));
|
||||
VecZnxBorrow {
|
||||
n: n,
|
||||
cols: cols,
|
||||
@@ -225,20 +227,20 @@ impl VecZnxApi for VecZnx {
|
||||
///
|
||||
/// User must ensure that data is properly alligned and that
|
||||
/// the size of data is at least equal to [VecZnx::bytes_of].
|
||||
fn from_bytes(n: usize, cols: usize, buf: &mut [u8]) -> Self::Owned {
|
||||
fn from_bytes(n: usize, cols: usize, bytes: &mut [u8]) -> Self::Owned {
|
||||
let size = Self::bytes_of(n, cols);
|
||||
assert!(
|
||||
buf.len() >= size,
|
||||
"invalid buffer: buf.len()={} < self.buffer_size(n={}, cols={})={}",
|
||||
buf.len(),
|
||||
debug_assert!(
|
||||
bytes.len() >= size,
|
||||
"invalid bytes: bytes.len()={} < self.bytes_of(n={}, cols={})={}",
|
||||
bytes.len(),
|
||||
n,
|
||||
cols,
|
||||
size
|
||||
);
|
||||
|
||||
debug_assert!(is_aligned(bytes.as_ptr()));
|
||||
VecZnx {
|
||||
n: n,
|
||||
data: alias_mut_slice_to_vec(cast_mut(&mut buf[..size])),
|
||||
data: alias_mut_slice_to_vec(cast_mut(&mut bytes[..size])),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -348,7 +350,7 @@ impl VecZnx {
|
||||
pub fn new(n: usize, cols: usize) -> Self {
|
||||
Self {
|
||||
n: n,
|
||||
data: alloc_aligned::<i64>(n * cols, 64),
|
||||
data: alloc_aligned::<i64>(n * cols),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -399,17 +401,18 @@ pub fn switch_degree<A: VecZnxCommon, B: VecZnxCommon>(b: &mut B, a: &A) {
|
||||
});
|
||||
}
|
||||
|
||||
fn normalize<T: VecZnxCommon>(log_base2k: usize, a: &mut T, carry: &mut [u8]) {
|
||||
fn normalize<T: VecZnxCommon>(log_base2k: usize, a: &mut T, tmp_bytes: &mut [u8]) {
|
||||
let n: usize = a.n();
|
||||
|
||||
assert!(
|
||||
carry.len() >= n * 8,
|
||||
"invalid carry: carry.len()={} < self.n()={}",
|
||||
carry.len(),
|
||||
debug_assert!(
|
||||
tmp_bytes.len() >= n * 8,
|
||||
"invalid tmp_bytes: tmp_bytes.len()={} < self.n()={}",
|
||||
tmp_bytes.len(),
|
||||
n
|
||||
);
|
||||
debug_assert!(is_aligned(tmp_bytes.as_ptr()));
|
||||
|
||||
let carry_i64: &mut [i64] = cast_mut(carry);
|
||||
let carry_i64: &mut [i64] = cast_mut(tmp_bytes);
|
||||
|
||||
unsafe {
|
||||
znx::znx_zero_i64_ref(n as u64, carry_i64.as_mut_ptr());
|
||||
@@ -426,16 +429,18 @@ fn normalize<T: VecZnxCommon>(log_base2k: usize, a: &mut T, carry: &mut [u8]) {
|
||||
}
|
||||
}
|
||||
|
||||
pub fn rsh<T: VecZnxCommon>(log_base2k: usize, a: &mut T, k: usize, carry: &mut [u8]) {
|
||||
pub fn rsh<T: VecZnxCommon>(log_base2k: usize, a: &mut T, k: usize, tmp_bytes: &mut [u8]) {
|
||||
let n: usize = a.n();
|
||||
|
||||
assert!(
|
||||
carry.len() >> 3 >= n,
|
||||
debug_assert!(
|
||||
tmp_bytes.len() >> 3 >= n,
|
||||
"invalid carry: carry.len()/8={} < self.n()={}",
|
||||
carry.len() >> 3,
|
||||
tmp_bytes.len() >> 3,
|
||||
n
|
||||
);
|
||||
|
||||
debug_assert!(is_aligned(tmp_bytes.as_ptr()));
|
||||
|
||||
let cols: usize = a.cols();
|
||||
let cols_steps: usize = k / log_base2k;
|
||||
|
||||
@@ -447,7 +452,7 @@ pub fn rsh<T: VecZnxCommon>(log_base2k: usize, a: &mut T, k: usize, carry: &mut
|
||||
let k_rem = k % log_base2k;
|
||||
|
||||
if k_rem != 0 {
|
||||
let carry_i64: &mut [i64] = cast_mut(carry);
|
||||
let carry_i64: &mut [i64] = cast_mut(tmp_bytes);
|
||||
|
||||
unsafe {
|
||||
znx::znx_zero_i64_ref(n as u64, carry_i64.as_mut_ptr());
|
||||
@@ -469,7 +474,6 @@ pub fn rsh<T: VecZnxCommon>(log_base2k: usize, a: &mut T, k: usize, carry: &mut
|
||||
pub trait VecZnxCommon: VecZnxApi + Infos {}
|
||||
|
||||
pub trait VecZnxOps {
|
||||
|
||||
/// Allocates a new [VecZnx].
|
||||
///
|
||||
/// # Arguments
|
||||
@@ -560,10 +564,8 @@ impl VecZnxOps for Module {
|
||||
self.n() * cols * 8
|
||||
}
|
||||
|
||||
fn vec_znx_normalize_tmp_bytes(&self) -> usize{
|
||||
unsafe{
|
||||
vec_znx::vec_znx_normalize_base2k_tmp_bytes(self.0) as usize
|
||||
}
|
||||
fn vec_znx_normalize_tmp_bytes(&self) -> usize {
|
||||
unsafe { vec_znx::vec_znx_normalize_base2k_tmp_bytes(self.0) as usize }
|
||||
}
|
||||
|
||||
// c <- a + b
|
||||
@@ -750,9 +752,9 @@ impl VecZnxOps for Module {
|
||||
a: &A,
|
||||
a_cols: usize,
|
||||
) {
|
||||
assert_eq!(a.n(), self.n());
|
||||
assert_eq!(b.n(), self.n());
|
||||
assert!(a.cols() >= a_cols);
|
||||
debug_assert_eq!(a.n(), self.n());
|
||||
debug_assert_eq!(b.n(), self.n());
|
||||
debug_assert!(a.cols() >= a_cols);
|
||||
unsafe {
|
||||
vec_znx::vec_znx_automorphism(
|
||||
self.0,
|
||||
@@ -803,8 +805,8 @@ impl VecZnxOps for Module {
|
||||
/// izip!(a.data.iter(), b.data.iter()).for_each(|(a, b)| assert_eq!(a, b, "{} != {}", a, b));
|
||||
/// ```
|
||||
fn vec_znx_automorphism_inplace<A: VecZnxCommon>(&self, k: i64, a: &mut A, a_cols: usize) {
|
||||
assert_eq!(a.n(), self.n());
|
||||
assert!(a.cols() >= a_cols);
|
||||
debug_assert_eq!(a.n(), self.n());
|
||||
debug_assert!(a.cols() >= a_cols);
|
||||
unsafe {
|
||||
vec_znx::vec_znx_automorphism(
|
||||
self.0,
|
||||
@@ -827,12 +829,12 @@ impl VecZnxOps for Module {
|
||||
) {
|
||||
let (n_in, n_out) = (a.n(), b[0].n());
|
||||
|
||||
assert!(
|
||||
debug_assert!(
|
||||
n_out < n_in,
|
||||
"invalid a: output ring degree should be smaller"
|
||||
);
|
||||
b[1..].iter().for_each(|bi| {
|
||||
assert_eq!(
|
||||
debug_assert_eq!(
|
||||
bi.n(),
|
||||
n_out,
|
||||
"invalid input a: all VecZnx must have the same degree"
|
||||
@@ -853,12 +855,12 @@ impl VecZnxOps for Module {
|
||||
fn vec_znx_merge<A: VecZnxCommon, B: VecZnxCommon>(&self, b: &mut B, a: &Vec<A>) {
|
||||
let (n_in, n_out) = (b.n(), a[0].n());
|
||||
|
||||
assert!(
|
||||
debug_assert!(
|
||||
n_out < n_in,
|
||||
"invalid a: output ring degree should be smaller"
|
||||
);
|
||||
a[1..].iter().for_each(|ai| {
|
||||
assert_eq!(
|
||||
debug_assert_eq!(
|
||||
ai.n(),
|
||||
n_out,
|
||||
"invalid input a: all VecZnx must have the same degree"
|
||||
|
||||
Reference in New Issue
Block a user