diff --git a/base2k/src/vec_znx_big.rs b/base2k/src/vec_znx_big.rs index a7bdd59..e1f656f 100644 --- a/base2k/src/vec_znx_big.rs +++ b/base2k/src/vec_znx_big.rs @@ -12,6 +12,25 @@ pub struct VecZnxBig { } impl VecZnxBig { + + pub fn new(module: &Module, cols: usize, limbs: usize) -> Self { + #[cfg(debug_assertions)] + { + assert!(cols > 0); + assert!(limbs > 0); + } + let mut data: Vec = alloc_aligned::(module.bytes_of_vec_znx_big(cols, limbs)); + let ptr: *mut u8 = data.as_mut_ptr(); + Self { + data: data, + ptr: ptr, + n: module.n(), + cols: cols, + limbs: limbs, + _marker: PhantomData, + } + } + /// Returns a new [VecZnxBig] with the provided data as backing array. /// User must ensure that data is properly alligned and that /// the size of data is at least equal to [Module::bytes_of_vec_znx_big]. @@ -64,22 +83,74 @@ impl VecZnxBig { } } - /// Returns a non-mutable reference to the entire contiguous array of the [VecZnxDft]. - pub fn raw(&self) -> &[i64] { - let ptr: *const i64 = self.ptr as *const i64; - unsafe { &std::slice::from_raw_parts(ptr, self.n() * self.poly_count()) } + /// Returns a non-mutable pointer to the backedn slice of the receiver. + pub fn as_ptr(&self) -> *const i64 { + self.ptr as *const i64 + } + + /// Returns a mutable pointer to the backedn slice of the receiver. + pub fn as_mut_ptr(&mut self) -> *mut i64 { + self.ptr as *mut i64 + } + + pub fn raw(&self) -> &[i64] { + unsafe { &std::slice::from_raw_parts(self.as_ptr(), self.n() * self.poly_count()) } + } + + pub fn raw_mut(&mut self) -> &mut [i64] { + let ptr: *mut i64 = self.ptr as *mut i64; + let size: usize = self.n() * self.poly_count(); + unsafe { std::slice::from_raw_parts_mut(ptr, size) } + } + + pub fn at_ptr(&self, i: usize, j: usize) -> *const i64 { + #[cfg(debug_assertions)] + { + assert!(i < self.cols()); + assert!(j < self.limbs()); + } + let offset: usize = self.n * (j * self.cols() + i); + self.as_ptr().wrapping_add(offset) + } + + /// Returns a non-mutable reference to the i-th limb. + /// The returned array is of size [Self::n()] * [Self::cols()]. + pub fn at_limb(&self, i: usize) -> &[i64] { + unsafe { std::slice::from_raw_parts(self.at_ptr(0, i), self.n * self.cols()) } + } + + /// Returns a non-mutable reference to the (i, j)-th poly. + /// The returned array is of size [Self::n()]. + pub fn at_poly(&self, i: usize, j: usize) -> &[i64] { + unsafe { std::slice::from_raw_parts(self.at_ptr(i, j), self.n) } + } + + /// Returns a mutable pointer starting a the (i, j)-th small poly. + pub fn at_mut_ptr(&mut self, i: usize, j: usize) -> *mut i64 { + #[cfg(debug_assertions)] + { + assert!(i < self.cols()); + assert!(j < self.limbs()); + } + + let offset: usize = self.n * (j * self.cols() + i); + self.as_mut_ptr().wrapping_add(offset) + } + + /// Returns a mutable reference to the i-th limb. + /// The returned array is of size [Self::n()] * [Self::cols()]. + pub fn at_limb_mut(&mut self, i: usize) -> &mut [i64] { + unsafe { std::slice::from_raw_parts_mut(self.at_mut_ptr(0, i), self.n * self.cols()) } + } + + /// Returns a mutable reference to the (i, j)-th poly. + /// The returned array is of size [Self::n()]. + pub fn at_poly_mut(&mut self, i: usize, j: usize) -> &mut [i64] { + unsafe { std::slice::from_raw_parts_mut(self.at_mut_ptr(i, j), self.n) } } - // Prints the first `n` coefficients of each limb pub fn print(&self, n: usize) { - let raw: &[i64] = self.raw(); - (0..self.limbs()).for_each(|i| { - println!( - "{}: {:?}", - i, - &raw[i * self.n() * self.cols()..i * self.n() * self.cols() + n] - ) - }) + (0..self.limbs()).for_each(|i| println!("{}: {:?}", i, &self.at_limb(i)[..n])); } } @@ -192,21 +263,7 @@ pub trait VecZnxBigOps { impl VecZnxBigOps for Module { fn new_vec_znx_big(&self, cols: usize, limbs: usize) -> VecZnxBig { - #[cfg(debug_assertions)] - { - assert!(cols > 0); - assert!(limbs > 0); - } - let mut data: Vec = alloc_aligned::(self.bytes_of_vec_znx_big(cols, limbs)); - let ptr: *mut u8 = data.as_mut_ptr(); - VecZnxBig:: { - data: data, - ptr: ptr, - n: self.n(), - cols: cols, - limbs: limbs, - _marker: PhantomData, - } + VecZnxBig::new(self, cols, limbs) } fn new_vec_znx_big_from_bytes(&self, cols: usize, limbs: usize, bytes: &mut [u8]) -> VecZnxBig { diff --git a/base2k/src/vec_znx_dft.rs b/base2k/src/vec_znx_dft.rs index 61c2a85..d984cdd 100644 --- a/base2k/src/vec_znx_dft.rs +++ b/base2k/src/vec_znx_dft.rs @@ -16,6 +16,11 @@ pub struct VecZnxDft { impl VecZnxDft { pub fn new(module: &Module, cols: usize, limbs: usize) -> Self { + #[cfg(debug_assertions)] + { + assert!(cols > 0); + assert!(limbs > 0); + } let mut data: Vec = alloc_aligned::(module.bytes_of_vec_znx_dft(cols, limbs)); let ptr: *mut u8 = data.as_mut_ptr(); Self {