diff --git a/base2k/examples/rlwe_encrypt.rs b/base2k/examples/rlwe_encrypt.rs
index 1a76689..092efcc 100644
--- a/base2k/examples/rlwe_encrypt.rs
+++ b/base2k/examples/rlwe_encrypt.rs
@@ -4,9 +4,9 @@ use sampling::source::Source;
 fn main() {
     let n: usize = 16;
-    let log_base2k: usize = 40;
-    let prec: usize = 54;
-    let log_scale: usize = 18;
+    let log_base2k: usize = 18;
+    let limbs: usize = 3;
+    let log_scale: usize = (limbs - 1) * log_base2k - 5;

     let module: Module = Module::new::<FFT64>(n);
     let mut carry: Vec<u8> = vec![0; module.vec_znx_big_normalize_tmp_bytes()];

@@ -14,7 +14,7 @@ fn main() {
     let seed: [u8; 32] = [0; 32];
     let mut source: Source = Source::new(seed);

-    let mut res: VecZnx = VecZnx::new(n, log_base2k, prec);
+    let mut res: VecZnx = VecZnx::new(n, log_base2k, limbs);

     // s <- Z_{-1, 0, 1}[X]/(X^{N}+1)
     let mut s: Scalar = Scalar::new(n);
@@ -27,8 +27,8 @@ fn main() {
     module.svp_prepare(&mut s_ppol, &s);

     // a <- Z_{2^prec}[X]/(X^{N}+1)
-    let mut a: VecZnx = VecZnx::new(n, log_base2k, prec);
-    a.fill_uniform(&mut source);
+    let mut a: VecZnx = VecZnx::new(n, log_base2k, limbs);
+    a.fill_uniform(&mut source, log_base2k * limbs);

     // Scratch space for DFT values
     let mut buf_dft: VecZnxDft = module.new_vec_znx_dft(a.limbs());
@@ -42,22 +42,23 @@ fn main() {
     // buf_big <- IDFT(buf_dft) (not normalized)
     module.vec_znx_idft_tmp_a(&mut buf_big, &mut buf_dft, a.limbs());

-    let mut m: VecZnx = VecZnx::new(n, log_base2k, prec - log_scale);
+    let mut m: VecZnx = VecZnx::new(n, log_base2k, 2);
+
     let mut want: Vec<i64> = vec![0; n];
     want.iter_mut()
         .for_each(|x| *x = source.next_u64n(16, 15) as i64);

     // m
-    m.from_i64(&want, 4);
+    m.from_i64(&want, 4, log_scale);
     m.normalize(&mut carry);

     // buf_big <- m - buf_big
     module.vec_znx_big_sub_small_a_inplace(&mut buf_big, &m);

     // b <- normalize(buf_big) + e
-    let mut b: VecZnx = VecZnx::new(n, log_base2k, prec);
+    let mut b: VecZnx = VecZnx::new(n, log_base2k, limbs);
     module.vec_znx_big_normalize(&mut b, &buf_big, &mut carry);
-    b.add_normal(&mut source, 3.2, 19.0);
+    b.add_normal(&mut source, 3.2, 19.0, log_base2k * limbs);

     //Decrypt

@@ -73,9 +74,9 @@ fn main() {
     // have = m * 2^{log_scale} + e
     let mut have: Vec<i64> = vec![i64::default(); n];
-    res.to_i64(&mut have);
+    res.to_i64(&mut have, res.limbs() * log_base2k);

-    let scale: f64 = (1 << log_scale) as f64;
+    let scale: f64 = (1 << (res.limbs() * log_base2k - log_scale)) as f64;

     izip!(want.iter(), have.iter())
         .enumerate()
         .for_each(|(i, (a, b))| {
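Note: the change above moves VecZnx from a stored bit precision (prec) to an explicit limb count, with the target precision log_k passed per call. A minimal standalone sketch of the limb/precision arithmetic this relies on, with a hypothetical helper name (not part of the patch):

    // ceil(log_k / log_base2k): limbs needed to hold log_k bits in base 2^log_base2k.
    fn limbs_for_precision(log_k: usize, log_base2k: usize) -> usize {
        (log_k + log_base2k - 1) / log_base2k
    }

    fn main() {
        let (log_base2k, limbs) = (18usize, 3usize);
        let log_scale = (limbs - 1) * log_base2k - 5; // 31, as in the example
        assert_eq!(limbs_for_precision(log_scale, log_base2k), 2); // m fits in 2 limbs
        assert_eq!(limbs_for_precision(log_base2k * limbs, log_base2k), limbs);
    }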
diff --git a/base2k/examples/vector_matrix_product.rs b/base2k/examples/vector_matrix_product.rs
index c93a7f07e..76e7247 100644
--- a/base2k/examples/vector_matrix_product.rs
+++ b/base2k/examples/vector_matrix_product.rs
@@ -1,4 +1,5 @@
 use base2k::{Matrix3D, Module, VecZnx, VecZnxBig, VecZnxDft, VmpPMat, FFT64};
+use std::cmp::min;

 fn main() {
     let log_n = 5;
@@ -6,10 +7,10 @@ fn main() {
     let module: Module = Module::new::<FFT64>(n);

     let log_base2k: usize = 15;
-    let log_q: usize = 60;
-    let limbs: usize = (log_q + log_base2k - 1) / log_base2k;
+    let limbs: usize = 5;
+    let log_k: usize = log_base2k * limbs - 5;

-    let rows: usize = limbs + 1;
+    let rows: usize = limbs;
     let cols: usize = limbs + 1;

     // Maximum size of the byte scratch needed
@@ -21,23 +22,18 @@ fn main() {
     let mut a_values: Vec<i64> = vec![i64::default(); n];
     a_values[1] = (1 << log_base2k) + 1;

-    let mut a: VecZnx = module.new_vec_znx(log_base2k, log_q);
-    a.from_i64(&a_values, 32);
+    let mut a: VecZnx = module.new_vec_znx(log_base2k, limbs);
+    a.from_i64(&a_values, 32, log_k);
     a.normalize(&mut buf);

     (0..a.limbs()).for_each(|i| println!("{}: {:?}", i, a.at(i)));

     let mut b_mat: Matrix3D<i64> = Matrix3D::new(rows, cols, n);

-    (0..rows).for_each(|i| {
-        (0..cols).for_each(|j| {
-            b_mat.at_mut(i, j)[0] = (i * cols + j) as i64;
-            b_mat.at_mut(i, j)[0] = (i * cols + j) as i64;
-        })
+    (0..min(rows, cols)).for_each(|i| {
+        b_mat.at_mut(i, i)[1] = 1 as i64;
     });

-    //b_mat.data.iter_mut().enumerate().for_each(|(i, xi)| *xi = i as i64);
-
     println!();

     (0..rows).for_each(|i| {
         (0..cols).for_each(|j| println!("{} {}: {:?}", i, j, b_mat.at(i, j)));
@@ -47,24 +43,26 @@ fn main() {
     let mut vmp_pmat: VmpPMat = module.new_vmp_pmat(rows, cols);
     module.vmp_prepare_contiguous(&mut vmp_pmat, &b_mat.data, &mut buf);

+    /*
     (0..cols).for_each(|i| {
         (0..rows).for_each(|j| println!("{} {}: {:?}", i, j, vmp_pmat.at(i, j)));
         println!();
     });
+    */

-    println!("{:?}", vmp_pmat.as_f64());
+    //println!("{:?}", vmp_pmat.as_f64());

-    let mut c_dft: VecZnxDft = module.new_vec_znx_dft(limbs);
+    let mut c_dft: VecZnxDft = module.new_vec_znx_dft(cols);
     module.vmp_apply_dft(&mut c_dft, &a, &vmp_pmat, &mut buf);

     let mut c_big: VecZnxBig = c_dft.as_vec_znx_big();
-    module.vec_znx_idft_tmp_a(&mut c_big, &mut c_dft, limbs);
+    module.vec_znx_idft_tmp_a(&mut c_big, &mut c_dft, cols);

-    let mut res: VecZnx = module.new_vec_znx(log_base2k, log_q);
+    let mut res: VecZnx = module.new_vec_znx(log_base2k, cols);
     module.vec_znx_big_normalize(&mut res, &c_big, &mut buf);

     let mut values_res: Vec<i64> = vec![i64::default(); n];
-    res.to_i64(&mut values_res);
+    res.to_i64(&mut values_res, log_k);

     (0..res.limbs()).for_each(|i| println!("{}: {:?}", i, res.at(i)));

@@ -72,5 +70,5 @@ fn main() {
     c_dft.delete();
     vmp_pmat.delete();

-    println!("{:?}", values_res)
+    //println!("{:?}", values_res)
 }
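Note: the example now fills b_mat with the monomial X on its diagonal (b_mat.at_mut(i, i)[1] = 1), so the vector-matrix product should return a with each limb multiplied by X. A scalar sketch of the underlying gadget product, under the assumption that row i of the prepared matrix is paired with the i-th base-2^k decomposition digit of the input (hypothetical helper, not the crate API):

    fn gadget_product(digits: &[i64], rows: &[i64]) -> i64 {
        // sum_i digits[i] * rows[i]
        digits.iter().zip(rows.iter()).map(|(d, r)| d * r).sum()
    }

    fn main() {
        let log_base2k = 15;
        let digits = [1i64, 1]; // (1 << 15) + 1 decomposed in base 2^15
        let rows = [1i64 << log_base2k, 1]; // "identity" gadget rows
        assert_eq!(gadget_product(&digits, &rows), (1 << log_base2k) + 1);
    }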
diff --git a/base2k/src/ffi/vec_znx.rs b/base2k/src/ffi/vec_znx.rs
index 897ef04..26dec8c 100644
--- a/base2k/src/ffi/vec_znx.rs
+++ b/base2k/src/ffi/vec_znx.rs
@@ -69,10 +69,10 @@ unsafe extern "C" {
 }
 unsafe extern "C" {
-    pub fn vec_znx_zero(module: *const MODULE, res: *mut i64, res_size: u64, res_sl: u64);
+    pub unsafe fn vec_znx_zero(module: *const MODULE, res: *mut i64, res_size: u64, res_sl: u64);
 }
 unsafe extern "C" {
-    pub fn vec_znx_copy(
+    pub unsafe fn vec_znx_copy(
         module: *const MODULE,
         res: *mut i64,
         res_size: u64,
@@ -84,7 +84,7 @@ unsafe extern "C" {
 }
 unsafe extern "C" {
-    pub fn vec_znx_normalize_base2k(
+    pub unsafe fn vec_znx_normalize_base2k(
         module: *const MODULE,
         log2_base2k: u64,
         res: *mut i64,
         res_size: u64,
         res_sl: u64,
         a: *const i64,
         a_size: u64,
         a_sl: u64,
         tmp_space: *mut u8,
     );
 }
@@ -97,5 +97,5 @@ unsafe extern "C" {
 unsafe extern "C" {
-    pub fn vec_znx_normalize_base2k_tmp_bytes(module: *const MODULE) -> u64;
+    pub unsafe fn vec_znx_normalize_base2k_tmp_bytes(module: *const MODULE) -> u64;
 }
diff --git a/base2k/src/ffi/vec_znx_dft.rs b/base2k/src/ffi/vec_znx_dft.rs
index 6f43683..54fb117 100644
--- a/base2k/src/ffi/vec_znx_dft.rs
+++ b/base2k/src/ffi/vec_znx_dft.rs
@@ -19,10 +19,10 @@ unsafe extern "C" {
 }
 unsafe extern "C" {
-    pub fn vec_dft_zero(module: *const MODULE, res: *mut VEC_ZNX_DFT, res_size: u64);
+    pub unsafe fn vec_dft_zero(module: *const MODULE, res: *mut VEC_ZNX_DFT, res_size: u64);
 }
 unsafe extern "C" {
-    pub fn vec_dft_add(
+    pub unsafe fn vec_dft_add(
         module: *const MODULE,
         res: *mut VEC_ZNX_DFT,
         res_size: u64,
@@ -33,7 +33,7 @@ unsafe extern "C" {
     );
 }
 unsafe extern "C" {
-    pub fn vec_dft_sub(
+    pub unsafe fn vec_dft_sub(
         module: *const MODULE,
         res: *mut VEC_ZNX_DFT,
         res_size: u64,
@@ -44,7 +44,7 @@ unsafe extern "C" {
     );
 }
 unsafe extern "C" {
-    pub fn vec_znx_dft(
+    pub unsafe fn vec_znx_dft(
         module: *const MODULE,
         res: *mut VEC_ZNX_DFT,
         res_size: u64,
@@ -54,7 +54,7 @@ unsafe extern "C" {
     );
 }
 unsafe extern "C" {
-    pub fn vec_znx_idft(
+    pub unsafe fn vec_znx_idft(
         module: *const MODULE,
         res: *mut VEC_ZNX_BIG,
         res_size: u64,
@@ -64,10 +64,10 @@ unsafe extern "C" {
     );
 }
 unsafe extern "C" {
-    pub fn vec_znx_idft_tmp_bytes(module: *const MODULE) -> u64;
+    pub unsafe fn vec_znx_idft_tmp_bytes(module: *const MODULE) -> u64;
 }
 unsafe extern "C" {
-    pub fn vec_znx_idft_tmp_a(
+    pub unsafe fn vec_znx_idft_tmp_a(
         module: *const MODULE,
         res: *mut VEC_ZNX_BIG,
         res_size: u64,
diff --git a/base2k/src/ffi/vmp.rs b/base2k/src/ffi/vmp.rs
index e51e31c..44ed6c0 100644
--- a/base2k/src/ffi/vmp.rs
+++ b/base2k/src/ffi/vmp.rs
@@ -6,6 +6,8 @@ use crate::ffi::vec_znx_dft::VEC_ZNX_DFT;
 pub struct vmp_pmat_t {
     _unused: [u8; 0],
 }
+
+// [rows][cols] = [#Decomposition][#Limbs]
 pub type VMP_PMAT = vmp_pmat_t;

 unsafe extern "C" {
@@ -77,6 +79,9 @@ unsafe extern "C" {
         tmp_space: *mut u8,
     );
 }
+
+/*
+NOT IMPLEMENTED IN SPQLIOS
 unsafe extern "C" {
     pub unsafe fn vmp_prepare_dblptr(
         module: *const MODULE,
@@ -87,6 +92,7 @@ unsafe extern "C" {
         tmp_space: *mut u8,
     );
 }
+*/
 unsafe extern "C" {
     pub unsafe fn vmp_prepare_contiguous_tmp_bytes(
         module: *const MODULE,
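Note: the FFI blocks above adopt the Rust 2024 `unsafe extern` form, in which each item is explicitly marked `unsafe` (or `safe`); calls still require an `unsafe` block. A self-contained illustration of the same pattern with a libc symbol (not one of the spqlios bindings):

    unsafe extern "C" {
        // declared unsafe: the caller must uphold the C contract
        pub unsafe fn strlen(s: *const core::ffi::c_char) -> usize;
    }

    fn checked_len(s: &std::ffi::CStr) -> usize {
        // SAFETY: a CStr is guaranteed valid and NUL-terminated.
        unsafe { strlen(s.as_ptr()) }
    }

    fn main() {
        let s = std::ffi::CString::new("base2k").unwrap();
        assert_eq!(checked_len(&s), 6);
    }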
diff --git a/base2k/src/vec_znx.rs b/base2k/src/vec_znx.rs
index 534d5d0..ff52f83 100644
--- a/base2k/src/vec_znx.rs
+++ b/base2k/src/vec_znx.rs
@@ -18,38 +18,34 @@ impl Module {
 pub struct VecZnx {
     pub n: usize,
     pub log_base2k: usize,
-    pub log_q: usize,
     pub data: Vec<i64>,
 }

 impl VecZnx {
-    pub fn new(n: usize, log_base2k: usize, log_q: usize) -> Self {
+    pub fn new(n: usize, log_base2k: usize, limbs: usize) -> Self {
         Self {
             n: n,
             log_base2k: log_base2k,
-            log_q: log_q,
-            data: vec![i64::default(); Self::buffer_size(n, log_base2k, log_q)],
+            data: vec![i64::default(); Self::buffer_size(n, limbs)],
         }
     }

-    pub fn buffer_size(n: usize, log_base2k: usize, log_q: usize) -> usize {
-        n * ((log_q + log_base2k - 1) / log_base2k)
+    pub fn buffer_size(n: usize, limbs: usize) -> usize {
+        n * limbs
     }

-    pub fn from_buffer(&mut self, n: usize, log_base2k: usize, log_q: usize, buf: &[i64]) {
-        let size = Self::buffer_size(n, log_base2k, log_q);
+    pub fn from_buffer(&mut self, n: usize, log_base2k: usize, limbs: usize, buf: &[i64]) {
+        let size = Self::buffer_size(n, limbs);
         assert!(
             buf.len() >= size,
-            "invalid buffer: buf.len()={} < self.buffer_size(n={}, k={}, log_q={})={}",
+            "invalid buffer: buf.len()={} < self.buffer_size(n={}, limbs={})={}",
             buf.len(),
             n,
-            log_base2k,
-            log_q,
+            limbs,
             size
         );
         self.n = n;
         self.log_base2k = log_base2k;
-        self.log_q = log_q;
         self.data = Vec::from(&buf[..size])
     }

@@ -61,10 +57,6 @@ impl VecZnx {
         self.n
     }

-    pub fn log_q(&self) -> usize {
-        self.log_q
-    }
-
     pub fn limbs(&self) -> usize {
         self.data.len() / self.n
     }
@@ -102,20 +94,21 @@ impl VecZnx {
         unsafe { znx_zero_i64_ref(self.data.len() as u64, self.data.as_mut_ptr()) }
     }

-    pub fn from_i64(&mut self, data: &[i64], log_max: usize) {
-        let size: usize = min(data.len(), self.n());
-        let k_rem: usize = self.log_base2k - (self.log_q % self.log_base2k);
+    pub fn from_i64(&mut self, data: &[i64], log_max: usize, log_k: usize) {
+        let limbs: usize = (log_k + self.log_base2k - 1) / self.log_base2k;

-        let limbs: usize = self.limbs();
+        assert!(limbs <= self.limbs(), "invalid argument log_k: (log_k + self.log_base2k - 1)/self.log_base2k={} > self.limbs()={}", limbs, self.limbs());
+
+        let size: usize = min(data.len(), self.n());
+        let log_k_rem: usize = self.log_base2k - (log_k % self.log_base2k);

         // If 2^{log_base2k} * 2^{k_rem} < 2^{63}-1, then we can simply copy
         // values on the last limb.
         // Else we decompose values base2k.
-        if log_max + k_rem < 63 || k_rem == self.log_base2k {
+        if log_max + log_k_rem < 63 || log_k_rem == self.log_base2k {
             (0..limbs - 1).for_each(|i| unsafe {
                 znx_zero_i64_ref(size as u64, self.at_mut(i).as_mut_ptr());
             });
-
             self.at_mut(self.limbs() - 1)[..size].copy_from_slice(&data[..size]);
         } else {
             let mask: i64 = (1 << self.log_base2k) - 1;
@@ -136,24 +129,28 @@ impl VecZnx {
         }

         // Case where self.prec % self.k != 0.
-        if k_rem != self.log_base2k {
+        if log_k_rem != self.log_base2k {
             let limbs = self.limbs();
             let steps: usize = min(limbs, (log_max + self.log_base2k - 1) / self.log_base2k);
             (limbs - steps..limbs).rev().for_each(|i| {
-                self.at_mut(i)[..size].iter_mut().for_each(|x| *x <<= k_rem);
+                self.at_mut(i)[..size]
+                    .iter_mut()
+                    .for_each(|x| *x <<= log_k_rem);
             })
         }
     }

-    pub fn from_i64_single(&mut self, i: usize, value: i64, log_max: usize) {
+    pub fn from_i64_single(&mut self, i: usize, value: i64, log_max: usize, log_k: usize) {
         assert!(i < self.n());
-        let k_rem: usize = self.log_base2k - (self.log_q % self.log_base2k);
+        let limbs: usize = (log_k + self.log_base2k - 1) / self.log_base2k;
+        assert!(limbs <= self.limbs(), "invalid argument log_k: (log_k + self.log_base2k - 1)/self.log_base2k={} > self.limbs()={}", limbs, self.limbs());
+        let log_k_rem: usize = self.log_base2k - (log_k % self.log_base2k);
         let limbs = self.limbs();

-        // If 2^{log_base2k} * 2^{k_rem} < 2^{63}-1, then we can simply copy
+        // If 2^{log_base2k} * 2^{log_k_rem} < 2^{63}-1, then we can simply copy
         // values on the last limb.
         // Else we decompose values base2k.
-        if log_max + k_rem < 63 || k_rem == self.log_base2k {
+        if log_max + log_k_rem < 63 || log_k_rem == self.log_base2k {
             (0..limbs - 1).for_each(|j| self.at_mut(j)[i] = 0);
             self.at_mut(self.limbs() - 1)[i] = value;
@@ -172,11 +169,11 @@ impl VecZnx {
         }

         // Case where self.prec % self.k != 0.
-        if k_rem != self.log_base2k {
+        if log_k_rem != self.log_base2k {
             let limbs = self.limbs();
             let steps: usize = min(limbs, (log_max + self.log_base2k - 1) / self.log_base2k);
             (limbs - steps..limbs).rev().for_each(|j| {
-                self.at_mut(j)[i] <<= k_rem;
+                self.at_mut(j)[i] <<= log_k_rem;
             })
         }
     }
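Note: from_i64 above now derives the limb count from the log_k argument instead of a stored log_q. A standalone sketch of the base-2^k split it performs, simplified by assuming log_k is a multiple of log_base2k (so the final log_k_rem shift is a no-op) and a non-negative input:

    fn decompose_base2k(value: i64, log_base2k: usize, limbs: usize) -> Vec<i64> {
        let mask: i64 = (1 << log_base2k) - 1;
        // limb 0 receives the most significant digit, matching to_i64's order
        (0..limbs)
            .rev()
            .map(|i| (value >> (i * log_base2k)) & mask)
            .collect()
    }

    fn main() {
        // (1 << 15) + 1 in base 2^15 over two limbs: digits [1, 1]
        assert_eq!(decompose_base2k((1 << 15) + 1, 15, 2), vec![1, 1]);
    }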
@@ -206,7 +203,8 @@ impl VecZnx {
         }
     }

-    pub fn to_i64(&self, data: &mut [i64]) {
+    pub fn to_i64(&self, data: &mut [i64], log_k: usize) {
+        let limbs: usize = (log_k + self.log_base2k - 1) / self.log_base2k;
         assert!(
             data.len() >= self.n,
             "invalid data: data.len()={} < self.n()={}",
             data.len(),
             self.n
         );
         data.copy_from_slice(self.at(0));
-        let rem: usize = self.log_base2k - (self.log_q % self.log_base2k);
-        (1..self.limbs()).for_each(|i| {
-            if i == self.limbs() - 1 && rem != self.log_base2k {
+        let rem: usize = self.log_base2k - (log_k % self.log_base2k);
+        (1..limbs).for_each(|i| {
+            if i == limbs - 1 && rem != self.log_base2k {
                 let k_rem: usize = self.log_base2k - rem;
                 izip!(self.at(i).iter(), data.iter_mut()).for_each(|(x, y)| {
                     *y = (*y << k_rem) + (x >> rem);
@@ -229,13 +227,14 @@ impl VecZnx {
         })
     }

-    pub fn to_i64_single(&self, i: usize) -> i64 {
+    pub fn to_i64_single(&self, i: usize, log_k: usize) -> i64 {
+        let limbs: usize = (log_k + self.log_base2k - 1) / self.log_base2k;
         assert!(i < self.n());
         let mut res: i64 = self.data[i];
-        let rem: usize = self.log_base2k - (self.log_q % self.log_base2k);
-        (1..self.limbs()).for_each(|i| {
+        let rem: usize = self.log_base2k - (log_k % self.log_base2k);
+        (1..limbs).for_each(|i| {
             let x = self.data[i * self.n];
-            if i == self.limbs() - 1 && rem != self.log_base2k {
+            if i == limbs - 1 && rem != self.log_base2k {
                 let k_rem: usize = self.log_base2k - rem;
                 res = (res << k_rem) + (x >> rem);
             } else {
@@ -260,7 +259,7 @@ impl VecZnx {
         }
     }

-    pub fn fill_uniform(&mut self, source: &mut Source) {
+    pub fn fill_uniform(&mut self, source: &mut Source, log_k: usize) {
         let mut base2k: u64 = 1 << self.log_base2k;
         let mut mask: u64 = base2k - 1;
         let mut base2k_half: i64 = (base2k >> 1) as i64;
@@ -271,7 +270,7 @@ impl VecZnx {
             .iter_mut()
             .for_each(|x| *x = (source.next_u64n(base2k, mask) as i64) - base2k_half);

-        let log_base2k_rem: usize = self.log_q % self.log_base2k;
+        let log_base2k_rem: usize = log_k % self.log_base2k;

         if log_base2k_rem != 0 {
             base2k = 1 << log_base2k_rem;
@@ -284,8 +283,14 @@ impl VecZnx {
             .for_each(|x| *x = (source.next_u64n(base2k, mask) as i64) - base2k_half);
     }

-    pub fn add_dist_f64<T: Distribution<f64>>(&mut self, source: &mut Source, dist: T, bound: f64) {
-        let log_base2k_rem: usize = self.log_q % self.log_base2k;
+    pub fn add_dist_f64<T: Distribution<f64>>(
+        &mut self,
+        source: &mut Source,
+        dist: T,
+        bound: f64,
+        log_k: usize,
+    ) {
+        let log_base2k_rem: usize = log_k % self.log_base2k;

         if log_base2k_rem != 0 {
             self.at_mut(self.limbs() - 1).iter_mut().for_each(|a| {
@@ -306,8 +311,8 @@ impl VecZnx {
         }
     }

-    pub fn add_normal(&mut self, source: &mut Source, sigma: f64, bound: f64) {
-        self.add_dist_f64(source, Normal::new(0.0, sigma).unwrap(), bound);
+    pub fn add_normal(&mut self, source: &mut Source, sigma: f64, bound: f64, log_k: usize) {
+        self.add_dist_f64(source, Normal::new(0.0, sigma).unwrap(), bound, log_k);
     }

     pub fn trunc_pow2(&mut self, k: usize) {
@@ -315,14 +320,6 @@ impl VecZnx {
             return;
         }

-        assert!(
-            k <= self.log_q,
-            "invalid argument k: k={} > self.prec()={}",
-            k,
-            self.log_q()
-        );
-
-        self.log_q -= k;
         self.data
             .truncate((self.limbs() - k / self.log_base2k) * self.n());
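Note: to_i64 above is the inverse walk of from_i64: starting from limb 0 it shifts the accumulator up by log_base2k and adds the next limb, with a partial shift on the last limb when log_k is not a multiple of log_base2k. Under the same simplifying assumption as the earlier sketch (log_k a multiple of log_base2k):

    fn recompose_base2k(digits: &[i64], log_base2k: usize) -> i64 {
        // Horner-style recombination, limb 0 being the most significant digit
        digits.iter().fold(0i64, |acc, d| (acc << log_base2k) + d)
    }

    fn main() {
        assert_eq!(recompose_base2k(&[1, 1], 15), (1 << 15) + 1);
    }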
@@ -337,6 +334,13 @@ impl VecZnx {
     }

     pub fn rsh(&mut self, k: usize, carry: &mut [u8]) {
+        assert!(
+            carry.len() >> 3 >= self.n(),
+            "invalid carry: carry.len()/8={} < self.n()={}",
+            carry.len() >> 3,
+            self.n()
+        );
+
         let limbs: usize = self.limbs();
         let limbs_steps: usize = k / self.log_base2k;
@@ -388,6 +392,10 @@ impl VecZnx {
                 .for_each(|(x_in, x_out)| *x_out = *x_in);
         });
     }
+
+    pub fn print_limbs(&self, limbs: usize, n: usize) {
+        (0..limbs).for_each(|i| println!("{}: {:?}", i, &self.at(i)[..n]))
+    }
 }

 #[cfg(test)]
@@ -398,26 +406,28 @@ mod tests {
     #[test]
     fn test_set_get_i64_lo_norm() {
-        let n: usize = 32;
-        let k: usize = 19;
-        let prec: usize = 128;
-        let mut a: VecZnx = VecZnx::new(n, k, prec);
+        let n: usize = 8;
+        let log_base2k: usize = 17;
+        let limbs: usize = 5;
+        let log_k: usize = limbs * log_base2k - 5;
+        let mut a: VecZnx = VecZnx::new(n, log_base2k, limbs);
         let mut have: Vec<i64> = vec![i64::default(); n];
         have.iter_mut()
             .enumerate()
             .for_each(|(i, x)| *x = (i as i64) - (n as i64) / 2);
-        a.from_i64(&have, 10);
+        a.from_i64(&have, 10, log_k);
         let mut want = vec![i64::default(); n];
-        a.to_i64(&mut want);
+        a.to_i64(&mut want, log_k);
         izip!(want, have).for_each(|(a, b)| assert_eq!(a, b));
     }

     #[test]
     fn test_set_get_i64_hi_norm() {
         let n: usize = 8;
-        let k: usize = 17;
-        let prec: usize = 84;
-        let mut a: VecZnx = VecZnx::new(n, k, prec);
+        let log_base2k: usize = 17;
+        let limbs: usize = 5;
+        let log_k: usize = limbs * log_base2k - 5;
+        let mut a: VecZnx = VecZnx::new(n, log_base2k, limbs);
         let mut have: Vec<i64> = vec![i64::default(); n];
         let mut source = Source::new([1; 32]);
         have.iter_mut().for_each(|x| {
@@ -425,19 +435,20 @@ mod tests {
                 .next_u64n(u64::MAX, u64::MAX)
                 .wrapping_sub(u64::MAX / 2 + 1) as i64;
         });
-        a.from_i64(&have, 63);
+        a.from_i64(&have, 63, log_k);
         //(0..a.limbs()).for_each(|i| println!("i:{} -> {:?}", i, a.at(i)));
         let mut want = vec![i64::default(); n];
         //(0..a.limbs()).for_each(|i| println!("i:{} -> {:?}", i, a.at(i)));
-        a.to_i64(&mut want);
+        a.to_i64(&mut want, log_k);
         izip!(want, have).for_each(|(a, b)| assert_eq!(a, b, "{} != {}", a, b));
     }

     #[test]
     fn test_normalize() {
         let n: usize = 8;
-        let k: usize = 17;
-        let prec: usize = 84;
-        let mut a: VecZnx = VecZnx::new(n, k, prec);
+        let log_base2k: usize = 17;
+        let limbs: usize = 5;
+        let log_k: usize = limbs * log_base2k - 5;
+        let mut a: VecZnx = VecZnx::new(n, log_base2k, limbs);
         let mut have: Vec<i64> = vec![i64::default(); n];
         let mut source = Source::new([1; 32]);
         have.iter_mut().for_each(|x| {
@@ -445,16 +456,16 @@ mod tests {
                 .next_u64n(u64::MAX, u64::MAX)
                 .wrapping_sub(u64::MAX / 2 + 1) as i64;
         });
-        a.from_i64(&have, 63);
+        a.from_i64(&have, 63, log_k);
         let mut carry: Vec<u8> = vec![u8::default(); n * 8];
         a.normalize(&mut carry);
-        let base_half = 1 << (k - 1);
+        let base_half = 1 << (log_base2k - 1);
         a.data
             .iter()
             .for_each(|x| assert!(x.abs() <= base_half, "|x|={} > 2^(k-1)={}", x, base_half));
         let mut want = vec![i64::default(); n];
-        a.to_i64(&mut want);
+        a.to_i64(&mut want, log_k);
         izip!(want, have).for_each(|(a, b)| assert_eq!(a, b, "{} != {}", a, b));
     }
 }
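Note: the new assert in rsh above encodes the scratch-buffer contract: carry is a byte slice reinterpreted as one i64 per coefficient, hence carry.len() >> 3 >= n, and the tests size it as vec![u8::default(); n * 8]. A small sizing helper one could use (hypothetical, not part of the crate):

    fn carry_bytes(n: usize) -> usize {
        n * std::mem::size_of::<i64>() // 8 bytes per coefficient
    }

    fn main() {
        let n = 8;
        let carry = vec![0u8; carry_bytes(n)];
        assert!(carry.len() >> 3 >= n);
    }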
diff --git a/base2k/src/vector_matrix_product.rs b/base2k/src/vector_matrix_product.rs
index a546a91..0136511 100644
--- a/base2k/src/vector_matrix_product.rs
+++ b/base2k/src/vector_matrix_product.rs
@@ -1,9 +1,10 @@
 use crate::ffi::vmp::{
     delete_vmp_pmat, new_vmp_pmat, vmp_apply_dft, vmp_apply_dft_tmp_bytes, vmp_apply_dft_to_dft,
     vmp_apply_dft_to_dft_tmp_bytes, vmp_pmat_t, vmp_prepare_contiguous,
-    vmp_prepare_contiguous_tmp_bytes, vmp_prepare_dblptr,
+    vmp_prepare_contiguous_tmp_bytes,
 };
 use crate::{Module, VecZnx, VecZnxDft};
+use std::cmp::min;

 pub struct VmpPMat {
     pub data: *mut vmp_pmat_t,
@@ -110,19 +111,32 @@ impl Module {
         }
     }

-    pub fn vmp_prepare_dblptr(&self, b: &mut VmpPMat, a: &Vec<&Vec<i64>>, buf: &mut [u8]) {
-        let ptrs: Vec<*const i64> = a.iter().map(|v| v.as_ptr()).collect();
+    pub fn vmp_prepare_dblptr(&self, b: &mut VmpPMat, a: &Vec<VecZnx>, buf: &mut [u8]) {
+        let rows: usize = b.rows();
+        let cols: usize = b.cols();
+        let mut mat: Matrix3D<i64> = Matrix3D::<i64>::new(rows, cols, self.n());
+
+        (0..min(rows, a.len())).for_each(|i| {
+            mat.set_row(i, &a[i].data);
+        });
+
+        self.vmp_prepare_contiguous(b, &mat.data, buf);
+
+        /*
+        NOT IMPLEMENTED IN SPQLIOS
+        let mut ptrs: Vec<*const i64> = a.iter().map(|v| v.data.as_ptr()).collect();
         unsafe {
             vmp_prepare_dblptr(
                 self.0,
                 b.data(),
-                ptrs.as_ptr(),
+                ptrs.as_mut_ptr(),
                 b.rows() as u64,
                 b.cols() as u64,
                 buf.as_mut_ptr(),
             );
         }
+        */
     }

     pub fn vmp_apply_dft_tmp_bytes(
@@ -237,18 +251,25 @@ impl<T: Copy> Matrix3D<T> {
     pub fn at(&self, row: usize, col: usize) -> &[T] {
         assert!(row <= self.rows && col <= self.cols);
-        let idx: usize = col * (self.n * self.rows) + row * self.n;
+        let idx: usize = row * (self.n * self.cols) + col * self.n;
         &self.data[idx..idx + self.n]
     }

     pub fn at_mut(&mut self, row: usize, col: usize) -> &mut [T] {
         assert!(row <= self.rows && col <= self.cols);
-        let idx: usize = col * (self.n * self.rows) + row * self.n;
+        let idx: usize = row * (self.n * self.cols) + col * self.n;
         &mut self.data[idx..idx + self.n]
     }

-    pub fn set_col(&mut self, col: usize, a: &[T]) {
-        let idx: usize = col * (self.n * self.rows);
-        self.data[idx..idx + self.rows * self.n].copy_from_slice(a);
+    pub fn set_row(&mut self, row: usize, a: &[T]) {
+        assert!(
+            row < self.rows,
+            "invalid argument row: row={} > self.rows={}",
+            row,
+            self.rows
+        );
+        let idx: usize = row * (self.n * self.cols);
+        let size: usize = min(a.len(), self.cols * self.n);
+        self.data[idx..idx + size].copy_from_slice(&a[..size]);
     }
 }
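Note: the Matrix3D change above swaps the indexing from column-major to row-major: element (row, col) now starts at row * (n * cols) + col * n, so the cols * n values of one row are contiguous and set_row can be a single bounded copy. A quick standalone check of the offset arithmetic:

    fn offset(row: usize, col: usize, n: usize, cols: usize) -> usize {
        row * (n * cols) + col * n
    }

    fn main() {
        let (n, cols) = (4, 3);
        assert_eq!(offset(0, 1, n, cols), 4); // next polynomial in the same row
        assert_eq!(offset(1, 0, n, cols), 12); // rows are contiguous blocks of cols * n
    }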