diff --git a/base2k/examples/vector_matrix_product.rs b/base2k/examples/vector_matrix_product.rs index 16568b3..8e4c9da 100644 --- a/base2k/examples/vector_matrix_product.rs +++ b/base2k/examples/vector_matrix_product.rs @@ -32,16 +32,15 @@ fn main() { a.print_limbs(a.limbs(), n); println!(); - let mut vecznx: Vec= Vec::new(); - (0..rows).for_each(|_|{ + let mut vecznx: Vec = Vec::new(); + (0..rows).for_each(|_| { vecznx.push(module.new_vec_znx(cols)); }); - (0..rows).for_each(|i|{ - vecznx[i].data[i*n+1] = 1 as i64; + (0..rows).for_each(|i| { + vecznx[i].data[i * n + 1] = 1 as i64; }); - let mut vmp_pmat: VmpPMat = module.new_vmp_pmat(rows, cols); module.vmp_prepare_dblptr(&mut vmp_pmat, &vecznx, &mut buf); @@ -57,7 +56,6 @@ fn main() { let mut values_res: Vec = vec![i64::default(); n]; res.decode_vec_i64(log_base2k, log_k, &mut values_res); - res.print_limbs(res.limbs(), n); module.free(); @@ -67,7 +65,6 @@ fn main() { //println!("{:?}", values_res) } - /* use base2k::{ @@ -110,4 +107,4 @@ fn main() { module.free(); } -*/ \ No newline at end of file +*/ diff --git a/base2k/src/ffi/vmp.rs b/base2k/src/ffi/vmp.rs index 9aeb184..154555c 100644 --- a/base2k/src/ffi/vmp.rs +++ b/base2k/src/ffi/vmp.rs @@ -92,9 +92,5 @@ unsafe extern "C" { } unsafe extern "C" { - pub unsafe fn vmp_prepare_tmp_bytes( - module: *const MODULE, - nrows: u64, - ncols: u64, - ) -> u64; + pub unsafe fn vmp_prepare_tmp_bytes(module: *const MODULE, nrows: u64, ncols: u64) -> u64; } diff --git a/base2k/src/vec_znx.rs b/base2k/src/vec_znx.rs index 377d765..24e22c1 100644 --- a/base2k/src/vec_znx.rs +++ b/base2k/src/vec_znx.rs @@ -408,26 +408,23 @@ pub trait VecZnxOps { fn vec_znx_automorphism_inplace(&self, k: i64, a: &mut VecZnx); /// Splits b into subrings and copies them them into a. - /// + /// /// # Panics - /// + /// /// This method requires that all [VecZnx] of b have the same ring degree /// and that b.n() * b.len() <= a.n() fn vec_znx_split(&self, b: &mut Vec, a: &VecZnx, buf: &mut VecZnx); - /// Merges the subrings a into b. - /// + /// /// # Panics - /// + /// /// This method requires that all [VecZnx] of a have the same ring degree /// and that a.n() * a.len() <= b.n() fn vec_znx_merge(&self, b: &mut VecZnx, a: &Vec); - } impl VecZnxOps for Module { - fn new_vec_znx(&self, limbs: usize) -> VecZnx { VecZnx::new(self.n(), limbs) } @@ -592,34 +589,48 @@ impl VecZnxOps for Module { } } - fn vec_znx_split(&self, b: &mut Vec, a: &VecZnx, buf: &mut VecZnx){ + fn vec_znx_split(&self, b: &mut Vec, a: &VecZnx, buf: &mut VecZnx) { let (n_in, n_out) = (a.n(), b[0].n()); - assert!(n_out < n_in, "invalid a: output ring degree should be smaller"); - b[1..].iter().for_each(|bi|{ - assert_eq!(bi.n(), n_out, "invalid input a: all VecZnx must have the same degree") + assert!( + n_out < n_in, + "invalid a: output ring degree should be smaller" + ); + b[1..].iter().for_each(|bi| { + assert_eq!( + bi.n(), + n_out, + "invalid input a: all VecZnx must have the same degree" + ) }); - b.iter_mut().enumerate().for_each(|(i, bi)|{ - if i == 0{ + b.iter_mut().enumerate().for_each(|(i, bi)| { + if i == 0 { a.switch_degree(bi); self.vec_znx_rotate(-1, buf, a); - }else{ + } else { buf.switch_degree(bi); self.vec_znx_rotate_inplace(-1, buf); } }) } - fn vec_znx_merge(&self, b: &mut VecZnx, a: &Vec){ + fn vec_znx_merge(&self, b: &mut VecZnx, a: &Vec) { let (n_in, n_out) = (b.n(), a[0].n()); - assert!(n_out < n_in, "invalid a: output ring degree should be smaller"); - a[1..].iter().for_each(|ai|{ - assert_eq!(ai.n(), n_out, "invalid input a: all VecZnx must have the same degree") + assert!( + n_out < n_in, + "invalid a: output ring degree should be smaller" + ); + a[1..].iter().for_each(|ai| { + assert_eq!( + ai.n(), + n_out, + "invalid input a: all VecZnx must have the same degree" + ) }); - a.iter().enumerate().for_each(|(i, ai)|{ + a.iter().enumerate().for_each(|(i, ai)| { ai.switch_degree(b); self.vec_znx_rotate_inplace(-1, b); }); diff --git a/base2k/src/vec_znx_big.rs b/base2k/src/vec_znx_big.rs index 874c83f..7113f00 100644 --- a/base2k/src/vec_znx_big.rs +++ b/base2k/src/vec_znx_big.rs @@ -5,6 +5,16 @@ use crate::{Infos, Module, VecZnx, VecZnxDft}; pub struct VecZnxBig(pub *mut vec_znx_big::vec_znx_bigcoeff_t, pub usize); impl VecZnxBig { + /// Casts a contiguous array of [u8] into as a [VecZnxDft]. + /// User must ensure that data is properly alligned and that + /// the size of data is at least equal to [Module::bytes_of_vec_znx_big]. + pub fn from_bytes(&self, limbs: usize, data: &mut [u8]) -> VecZnxBig { + VecZnxBig( + data.as_mut_ptr() as *mut vec_znx_big::vec_znx_bigcoeff_t, + limbs, + ) + } + pub fn as_vec_znx_dft(&mut self) -> VecZnxDft { VecZnxDft(self.0 as *mut vec_znx_dft::vec_znx_dft_t, self.1) } @@ -19,6 +29,10 @@ impl Module { unsafe { VecZnxBig(vec_znx_big::new_vec_znx_big(self.0, limbs as u64), limbs) } } + pub fn bytes_of_vec_znx_big(&self, limbs: usize) -> usize { + unsafe { vec_znx_big::bytes_of_vec_znx_big(self.0, limbs as u64) as usize } + } + // b <- b - a pub fn vec_znx_big_sub_small_a_inplace(&self, b: &mut VecZnxBig, a: &VecZnx) { let limbs: usize = a.limbs(); diff --git a/base2k/src/vec_znx_dft.rs b/base2k/src/vec_znx_dft.rs index 74732b9..29fef62 100644 --- a/base2k/src/vec_znx_dft.rs +++ b/base2k/src/vec_znx_dft.rs @@ -1,10 +1,21 @@ use crate::ffi::vec_znx_big; use crate::ffi::vec_znx_dft; +use crate::ffi::vec_znx_dft::bytes_of_vec_znx_dft; use crate::{Module, VecZnxBig}; pub struct VecZnxDft(pub *mut vec_znx_dft::vec_znx_dft_t, pub usize); impl VecZnxDft { + /// Casts a contiguous array of [u8] into as a [VecZnxDft]. + /// User must ensure that data is properly alligned and that + /// the size of data is at least equal to [Module::bytes_of_vec_znx_dft]. + pub fn from_bytes(&self, limbs: usize, data: &mut [u8]) -> VecZnxDft { + VecZnxDft(data.as_mut_ptr() as *mut vec_znx_dft::vec_znx_dft_t, limbs) + } + + /// Cast a [VecZnxDft] into a [VecZnxBig]. + /// The returned [VecZnxBig] shares the backing array + /// with the original [VecZnxDft]. pub fn as_vec_znx_big(&mut self) -> VecZnxBig { VecZnxBig(self.0 as *mut vec_znx_big::vec_znx_bigcoeff_t, self.1) } @@ -19,6 +30,10 @@ impl Module { unsafe { VecZnxDft(vec_znx_dft::new_vec_znx_dft(self.0, limbs as u64), limbs) } } + pub fn bytes_of_vec_znx_dft(&self, limbs: usize) -> usize { + unsafe { bytes_of_vec_znx_dft(self.0, limbs as u64) as usize } + } + // b <- IDFT(a), uses a as scratch space. pub fn vec_znx_idft_tmp_a(&self, b: &mut VecZnxBig, a: &mut VecZnxDft, a_limbs: usize) { assert!(