diff --git a/base2k/README.md b/base2k/README.md index bbb334d..08a49e9 100644 --- a/base2k/README.md +++ b/base2k/README.md @@ -1,11 +1,12 @@ -# DISCLAIMER: ONLY TESTED ON UBUNTU +## WSL/Ubuntu To use this crate you need to build spqlios-arithmetic, which is provided a as a git submodule: -1) Initialize the sub-modile +1) Initialize the sub-module 2) $ cd base2k/spqlios-arithmetic 3) mdkir build 4) cd build 5) cmake .. 6) make +## Others Steps 3 to 6 might change depending of your platform. See [spqlios-arithmetic/wiki/build](https://github.com/tfhe/spqlios-arithmetic/wiki/build) for additional information and build options. \ No newline at end of file diff --git a/base2k/examples/rlwe_encrypt.rs b/base2k/examples/rlwe_encrypt.rs index 092efcc..97f9f81 100644 --- a/base2k/examples/rlwe_encrypt.rs +++ b/base2k/examples/rlwe_encrypt.rs @@ -6,7 +6,8 @@ fn main() { let n: usize = 16; let log_base2k: usize = 18; let limbs: usize = 3; - let log_scale: usize = (limbs - 1) * log_base2k - 5; + let msg_limbs: usize = 2; + let log_scale: usize = msg_limbs * log_base2k - 5; let module: Module = Module::new::(n); let mut carry: Vec = vec![0; module.vec_znx_big_normalize_tmp_bytes()]; @@ -14,7 +15,7 @@ fn main() { let seed: [u8; 32] = [0; 32]; let mut source: Source = Source::new(seed); - let mut res: VecZnx = VecZnx::new(n, log_base2k, limbs); + let mut res: VecZnx = module.new_vec_znx(limbs); // s <- Z_{-1, 0, 1}[X]/(X^{N}+1) let mut s: Scalar = Scalar::new(n); @@ -27,8 +28,8 @@ fn main() { module.svp_prepare(&mut s_ppol, &s); // a <- Z_{2^prec}[X]/(X^{N}+1) - let mut a: VecZnx = VecZnx::new(n, log_base2k, limbs); - a.fill_uniform(&mut source, log_base2k * limbs); + let mut a: VecZnx = module.new_vec_znx(limbs); + a.fill_uniform(log_base2k, &mut source, limbs); // Scratch space for DFT values let mut buf_dft: VecZnxDft = module.new_vec_znx_dft(a.limbs()); @@ -42,23 +43,23 @@ fn main() { // buf_big <- IDFT(buf_dft) (not normalized) module.vec_znx_idft_tmp_a(&mut buf_big, &mut 
buf_dft, a.limbs()); - let mut m: VecZnx = VecZnx::new(n, log_base2k, 2); + let mut m: VecZnx = module.new_vec_znx(msg_limbs); let mut want: Vec = vec![0; n]; want.iter_mut() .for_each(|x| *x = source.next_u64n(16, 15) as i64); // m - m.from_i64(&want, 4, log_scale); - m.normalize(&mut carry); + m.from_i64(log_base2k, &want, 4, log_scale); + m.normalize(log_base2k, &mut carry); // buf_big <- m - buf_big module.vec_znx_big_sub_small_a_inplace(&mut buf_big, &m); // b <- normalize(buf_big) + e - let mut b: VecZnx = VecZnx::new(n, log_base2k, limbs); - module.vec_znx_big_normalize(&mut b, &buf_big, &mut carry); - b.add_normal(&mut source, 3.2, 19.0, log_base2k * limbs); + let mut b: VecZnx = module.new_vec_znx(limbs); + module.vec_znx_big_normalize(log_base2k, &mut b, &buf_big, &mut carry); + b.add_normal(log_base2k, &mut source, 3.2, 19.0, log_base2k * limbs); //Decrypt @@ -70,11 +71,11 @@ fn main() { module.vec_znx_big_add_small_inplace(&mut buf_big, &b); // res <- normalize(buf_big) - module.vec_znx_big_normalize(&mut res, &buf_big, &mut carry); + module.vec_znx_big_normalize(log_base2k, &mut res, &buf_big, &mut carry); // have = m * 2^{log_scale} + e let mut have: Vec = vec![i64::default(); n]; - res.to_i64(&mut have, res.limbs() * log_base2k); + res.to_i64(log_base2k, &mut have, res.limbs() * log_base2k); let scale: f64 = (1 << (res.limbs() * log_base2k - log_scale)) as f64; izip!(want.iter(), have.iter()) diff --git a/base2k/examples/vector_matrix_product.rs b/base2k/examples/vector_matrix_product.rs index 76e7247..89f64d3 100644 --- a/base2k/examples/vector_matrix_product.rs +++ b/base2k/examples/vector_matrix_product.rs @@ -1,4 +1,5 @@ -use base2k::{Matrix3D, Module, VecZnx, VecZnxBig, VecZnxDft, VmpPMat, FFT64}; +use base2k::vmp::VectorMatrixProduct; +use base2k::{Free, Matrix3D, Module, VecZnx, VecZnxBig, VecZnxDft, VmpPMat, FFT64}; use std::cmp::min; fn main() { @@ -22,9 +23,9 @@ fn main() { let mut a_values: Vec = vec![i64::default(); n]; a_values[1] = (1 
<< log_base2k) + 1; - let mut a: VecZnx = module.new_vec_znx(log_base2k, limbs); - a.from_i64(&a_values, 32, log_k); - a.normalize(&mut buf); + let mut a: VecZnx = module.new_vec_znx(limbs); + a.from_i64(log_base2k, &a_values, 32, log_k); + a.normalize(log_base2k, &mut buf); (0..a.limbs()).for_each(|i| println!("{}: {:?}", i, a.at(i))); @@ -34,41 +35,26 @@ fn main() { b_mat.at_mut(i, i)[1] = 1 as i64; }); - println!(); - (0..rows).for_each(|i| { - (0..cols).for_each(|j| println!("{} {}: {:?}", i, j, b_mat.at(i, j))); - println!(); - }); - let mut vmp_pmat: VmpPMat = module.new_vmp_pmat(rows, cols); module.vmp_prepare_contiguous(&mut vmp_pmat, &b_mat.data, &mut buf); - /* - (0..cols).for_each(|i| { - (0..rows).for_each(|j| println!("{} {}: {:?}", i, j, vmp_pmat.at(i, j))); - println!(); - }); - */ - - //println!("{:?}", vmp_pmat.as_f64()); - let mut c_dft: VecZnxDft = module.new_vec_znx_dft(cols); module.vmp_apply_dft(&mut c_dft, &a, &vmp_pmat, &mut buf); let mut c_big: VecZnxBig = c_dft.as_vec_znx_big(); module.vec_znx_idft_tmp_a(&mut c_big, &mut c_dft, cols); - let mut res: VecZnx = module.new_vec_znx(log_base2k, cols); - module.vec_znx_big_normalize(&mut res, &c_big, &mut buf); + let mut res: VecZnx = module.new_vec_znx(cols); + module.vec_znx_big_normalize(log_base2k, &mut res, &c_big, &mut buf); let mut values_res: Vec = vec![i64::default(); n]; - res.to_i64(&mut values_res, log_k); + res.to_i64(log_base2k, &mut values_res, log_k); (0..res.limbs()).for_each(|i| println!("{}: {:?}", i, res.at(i))); - module.delete(); - c_dft.delete(); - vmp_pmat.delete(); + module.free(); + c_dft.free(); + vmp_pmat.free(); //println!("{:?}", values_res) } diff --git a/base2k/src/lib.rs b/base2k/src/lib.rs index 91c6054..4ef66ee 100644 --- a/base2k/src/lib.rs +++ b/base2k/src/lib.rs @@ -31,13 +31,13 @@ pub mod vec_znx_dft; #[allow(unused_imports)] pub use vec_znx_dft::*; -pub mod scalar_vector_product; +pub mod svp; #[allow(unused_imports)] -pub use scalar_vector_product::*; +pub 
use svp::*; -pub mod vector_matrix_product; +pub mod vmp; #[allow(unused_imports)] -pub use vector_matrix_product::*; +pub use vmp::*; pub const GALOISGENERATOR: u64 = 5; @@ -65,3 +65,10 @@ pub fn cast_u8_to_f64_slice(data: &mut [u8]) -> &[f64] { let len: usize = data.len() / std::mem::size_of::(); unsafe { std::slice::from_raw_parts(ptr, len) } } + +/// This trait should be implemented by structs that point to +/// memory allocated through C. +pub trait Free { + // Frees the memory and self destructs. + fn free(self); +} diff --git a/base2k/src/module.rs b/base2k/src/module.rs index 2ce179a..6d1ce47 100644 --- a/base2k/src/module.rs +++ b/base2k/src/module.rs @@ -1,5 +1,5 @@ use crate::ffi::module::{delete_module_info, module_info_t, new_module_info, MODULE}; -use crate::GALOISGENERATOR; +use crate::{Free, GALOISGENERATOR}; pub type MODULETYPE = u8; pub const FFT64: u8 = 0; @@ -53,8 +53,10 @@ impl Module { (gal_el as i64) * gen.signum() } +} - pub fn delete(self) { +impl Free for Module { + fn free(self) { unsafe { delete_module_info(self.0) } drop(self); } diff --git a/base2k/src/scalar_vector_product.rs b/base2k/src/svp.rs similarity index 51% rename from base2k/src/scalar_vector_product.rs rename to base2k/src/svp.rs index 3c2b13e..bac0416 100644 --- a/base2k/src/scalar_vector_product.rs +++ b/base2k/src/svp.rs @@ -1,33 +1,52 @@ use crate::ffi::svp::{delete_svp_ppol, new_svp_ppol, svp_apply_dft, svp_ppol_t, svp_prepare}; use crate::scalar::Scalar; -use crate::{Module, VecZnx, VecZnxDft}; +use crate::{Free, Module, VecZnx, VecZnxDft}; pub struct SvpPPol(pub *mut svp_ppol_t, pub usize); +/// A prepared [crate::Scalar] for [ScalarVectorProduct::svp_apply_dft]. +/// An [SvpPPol] an be seen as a [VecZnxDft] of one limb. +/// The backend array of an [SvpPPol] is allocated in C and must be freed manually. impl SvpPPol { + /// Returns the ring degree of the [SvpPPol]. 
pub fn n(&self) -> usize { self.1 } - pub fn delete(self) { + /// Returns the number of limbs of the [SvpPPol], which is always 1. + pub fn limbs(&self) -> usize { + 1 + } +} + +impl Free for SvpPPol { + fn free(self) { unsafe { delete_svp_ppol(self.0) }; let _ = drop(self); } } +pub trait ScalarVectorProduct { + /// Prepares a [crate::Scalar] for a [ScalarVectorProduct::svp_apply_dft]. + fn svp_prepare(&self, svp_ppol: &mut SvpPPol, a: &Scalar); + + /// Allocates a new [SvpPPol]. + fn svp_new_ppol(&self) -> SvpPPol; + + /// Applies the [SvpPPol] x [VecZnxDft] product, where each limb of + /// the [VecZnxDft] is multiplied with [SvpPPol]. + fn svp_apply_dft(&self, c: &mut VecZnxDft, a: &SvpPPol, b: &VecZnx); +} + impl Module { - // Prepares a scalar polynomial (1 limb) for a scalar x vector product. - // Method will panic if a.limbs() != 1. pub fn svp_prepare(&self, svp_ppol: &mut SvpPPol, a: &Scalar) { unsafe { svp_prepare(self.0, svp_ppol.0, a.as_ptr()) } } - // Allocates a scalar-vector-product prepared-poly (VecZnxBig). 
pub fn svp_new_ppol(&self) -> SvpPPol { unsafe { SvpPPol(new_svp_ppol(self.0), self.n()) } } - // Applies a scalar x vector product: res <- a (ppol) x b pub fn svp_apply_dft(&self, c: &mut VecZnxDft, a: &SvpPPol, b: &VecZnx) { let limbs: u64 = b.limbs() as u64; assert!( diff --git a/base2k/src/vec_znx.rs b/base2k/src/vec_znx.rs index 98628dc..3e775b6 100644 --- a/base2k/src/vec_znx.rs +++ b/base2k/src/vec_znx.rs @@ -9,23 +9,21 @@ use sampling::source::Source; use std::cmp::min; impl Module { - pub fn new_vec_znx(&self, log_base2k: usize, limbs: usize) -> VecZnx { - VecZnx::new(self.n(), log_base2k, limbs) + pub fn new_vec_znx(&self, limbs: usize) -> VecZnx { + VecZnx::new(self.n(), limbs) } } #[derive(Clone)] pub struct VecZnx { pub n: usize, - pub log_base2k: usize, pub data: Vec, } impl VecZnx { - pub fn new(n: usize, log_base2k: usize, limbs: usize) -> Self { + pub fn new(n: usize, limbs: usize) -> Self { Self { n: n, - log_base2k: log_base2k, data: vec![i64::default(); Self::buffer_size(n, limbs)], } } @@ -34,7 +32,7 @@ impl VecZnx { n * limbs } - pub fn from_buffer(&mut self, n: usize, log_base2k: usize, limbs: usize, buf: &[i64]) { + pub fn from_buffer(&mut self, n: usize, limbs: usize, buf: &[i64]) { let size = Self::buffer_size(n, limbs); assert!( buf.len() >= size, @@ -45,7 +43,6 @@ impl VecZnx { size ); self.n = n; - self.log_base2k = log_base2k; self.data = Vec::from(&buf[..size]) } @@ -94,25 +91,25 @@ impl VecZnx { unsafe { znx_zero_i64_ref(self.data.len() as u64, self.data.as_mut_ptr()) } } - pub fn from_i64(&mut self, data: &[i64], log_max: usize, log_k: usize) { - let limbs: usize = (log_k + self.log_base2k - 1) / self.log_base2k; + pub fn from_i64(&mut self, log_base2k: usize, data: &[i64], log_max: usize, log_k: usize) { + let limbs: usize = (log_k + log_base2k - 1) / log_base2k; assert!(limbs <= self.limbs(), "invalid argument log_k: (log_k + self.log_base2k - 1)/self.log_base2k={} > self.limbs()={}", limbs, self.limbs()); let size: usize = 
min(data.len(), self.n()); - let log_k_rem: usize = self.log_base2k - (log_k % self.log_base2k); + let log_k_rem: usize = log_base2k - (log_k % log_base2k); // If 2^{log_base2k} * 2^{k_rem} < 2^{63}-1, then we can simply copy // values on the last limb. // Else we decompose values base2k. - if log_max + log_k_rem < 63 || log_k_rem == self.log_base2k { + if log_max + log_k_rem < 63 || log_k_rem == log_base2k { (0..limbs - 1).for_each(|i| unsafe { znx_zero_i64_ref(size as u64, self.at_mut(i).as_mut_ptr()); }); self.at_mut(self.limbs() - 1)[..size].copy_from_slice(&data[..size]); } else { - let mask: i64 = (1 << self.log_base2k) - 1; - let steps: usize = min(limbs, (log_max + self.log_base2k - 1) / self.log_base2k); + let mask: i64 = (1 << log_base2k) - 1; + let steps: usize = min(limbs, (log_max + log_base2k - 1) / log_base2k); (0..steps).for_each(|i| unsafe { znx_zero_i64_ref(size as u64, self.at_mut(i).as_mut_ptr()); @@ -122,16 +119,16 @@ impl VecZnx { .rev() .enumerate() .for_each(|(i, i_rev)| { - let shift: usize = i * self.log_base2k; + let shift: usize = i * log_base2k; izip!(self.at_mut(i_rev)[..size].iter_mut(), data[..size].iter()) .for_each(|(y, x)| *y = (x >> shift) & mask); }) } // Case where self.prec % self.k != 0. 
- if log_k_rem != self.log_base2k { + if log_k_rem != log_base2k { let limbs = self.limbs(); - let steps: usize = min(limbs, (log_max + self.log_base2k - 1) / self.log_base2k); + let steps: usize = min(limbs, (log_max + log_base2k - 1) / log_base2k); (limbs - steps..limbs).rev().for_each(|i| { self.at_mut(i)[..size] .iter_mut() @@ -140,23 +137,30 @@ impl VecZnx { } } - pub fn from_i64_single(&mut self, i: usize, value: i64, log_max: usize, log_k: usize) { + pub fn from_i64_single( + &mut self, + log_base2k: usize, + i: usize, + value: i64, + log_max: usize, + log_k: usize, + ) { assert!(i < self.n()); - let limbs: usize = (log_k + self.log_base2k - 1) / self.log_base2k; + let limbs: usize = (log_k + log_base2k - 1) / log_base2k; assert!(limbs <= self.limbs(), "invalid argument log_k: (log_k + self.log_base2k - 1)/self.log_base2k={} > self.limbs()={}", limbs, self.limbs()); - let log_k_rem: usize = self.log_base2k - (log_k % self.log_base2k); + let log_k_rem: usize = log_base2k - (log_k % log_base2k); let limbs = self.limbs(); // If 2^{log_base2k} * 2^{log_k_rem} < 2^{63}-1, then we can simply copy // values on the last limb. // Else we decompose values base2k. - if log_max + log_k_rem < 63 || log_k_rem == self.log_base2k { + if log_max + log_k_rem < 63 || log_k_rem == log_base2k { (0..limbs - 1).for_each(|j| self.at_mut(j)[i] = 0); self.at_mut(self.limbs() - 1)[i] = value; } else { - let mask: i64 = (1 << self.log_base2k) - 1; - let steps: usize = min(limbs, (log_max + self.log_base2k - 1) / self.log_base2k); + let mask: i64 = (1 << log_base2k) - 1; + let steps: usize = min(limbs, (log_max + log_base2k - 1) / log_base2k); (0..limbs - steps).for_each(|j| self.at_mut(j)[i] = 0); @@ -164,21 +168,21 @@ impl VecZnx { .rev() .enumerate() .for_each(|(j, j_rev)| { - self.at_mut(j_rev)[i] = (value >> (j * self.log_base2k)) & mask; + self.at_mut(j_rev)[i] = (value >> (j * log_base2k)) & mask; }) } // Case where self.prec % self.k != 0. 
- if log_k_rem != self.log_base2k { + if log_k_rem != log_base2k { let limbs = self.limbs(); - let steps: usize = min(limbs, (log_max + self.log_base2k - 1) / self.log_base2k); + let steps: usize = min(limbs, (log_max + log_base2k - 1) / log_base2k); (limbs - steps..limbs).rev().for_each(|j| { self.at_mut(j)[i] <<= log_k_rem; }) } } - pub fn normalize(&mut self, carry: &mut [u8]) { + pub fn normalize(&mut self, log_base2k: usize, carry: &mut [u8]) { assert!( carry.len() >= self.n * 8, "invalid carry: carry.len()={} < self.n()={}", @@ -193,7 +197,7 @@ impl VecZnx { (0..self.limbs()).rev().for_each(|i| { znx_normalize( self.n as u64, - self.log_base2k as u64, + log_base2k as u64, self.at_mut_ptr(i), carry_i64.as_mut_ptr(), self.at_mut_ptr(i), @@ -203,8 +207,8 @@ impl VecZnx { } } - pub fn to_i64(&self, data: &mut [i64], log_k: usize) { - let limbs: usize = (log_k + self.log_base2k - 1) / self.log_base2k; + pub fn to_i64(&self, log_base2k: usize, data: &mut [i64], log_k: usize) { + let limbs: usize = (log_k + log_base2k - 1) / log_base2k; assert!( data.len() >= self.n, "invalid data: data.len()={} < self.n()={}", @@ -212,33 +216,33 @@ impl VecZnx { self.n ); data.copy_from_slice(self.at(0)); - let rem: usize = self.log_base2k - (log_k % self.log_base2k); + let rem: usize = log_base2k - (log_k % log_base2k); (1..limbs).for_each(|i| { - if i == limbs - 1 && rem != self.log_base2k { - let k_rem: usize = self.log_base2k - rem; + if i == limbs - 1 && rem != log_base2k { + let k_rem: usize = log_base2k - rem; izip!(self.at(i).iter(), data.iter_mut()).for_each(|(x, y)| { *y = (*y << k_rem) + (x >> rem); }); } else { izip!(self.at(i).iter(), data.iter_mut()).for_each(|(x, y)| { - *y = (*y << self.log_base2k) + x; + *y = (*y << log_base2k) + x; }); } }) } - pub fn to_i64_single(&self, i: usize, log_k: usize) -> i64 { - let limbs: usize = (log_k + self.log_base2k - 1) / self.log_base2k; + pub fn to_i64_single(&self, log_base2k: usize, i: usize, log_k: usize) -> i64 { + let 
limbs: usize = (log_k + log_base2k - 1) / log_base2k; assert!(i < self.n()); let mut res: i64 = self.data[i]; - let rem: usize = self.log_base2k - (log_k % self.log_base2k); + let rem: usize = log_base2k - (log_k % log_base2k); (1..limbs).for_each(|i| { let x = self.data[i * self.n]; - if i == limbs - 1 && rem != self.log_base2k { - let k_rem: usize = self.log_base2k - rem; + if i == limbs - 1 && rem != log_base2k { + let k_rem: usize = log_base2k - rem; res = (res << k_rem) + (x >> rem); } else { - res = (res << self.log_base2k) + x; + res = (res << log_base2k) + x; } }); res @@ -259,38 +263,27 @@ impl VecZnx { } } - pub fn fill_uniform(&mut self, source: &mut Source, log_k: usize) { - let mut base2k: u64 = 1 << self.log_base2k; - let mut mask: u64 = base2k - 1; - let mut base2k_half: i64 = (base2k >> 1) as i64; + pub fn fill_uniform(&mut self, log_base2k: usize, source: &mut Source, limbs: usize) { + let base2k: u64 = 1 << log_base2k; + let mask: u64 = base2k - 1; + let base2k_half: i64 = (base2k >> 1) as i64; - let size: usize = self.n() * (self.limbs() - 1); + let size: usize = self.n() * (limbs - 1); self.data[..size] .iter_mut() .for_each(|x| *x = (source.next_u64n(base2k, mask) as i64) - base2k_half); - - let log_base2k_rem: usize = log_k % self.log_base2k; - - if log_base2k_rem != 0 { - base2k = 1 << log_base2k_rem; - mask = (base2k - 1) << (self.log_base2k - log_base2k_rem); - base2k_half = ((mask >> 1) + 1) as i64; - } - - self.data[size..] 
- .iter_mut() - .for_each(|x| *x = (source.next_u64n(base2k, mask) as i64) - base2k_half); } pub fn add_dist_f64>( &mut self, + log_base2k: usize, source: &mut Source, dist: T, bound: f64, log_k: usize, ) { - let log_base2k_rem: usize = log_k % self.log_base2k; + let log_base2k_rem: usize = log_k % log_base2k; if log_base2k_rem != 0 { self.at_mut(self.limbs() - 1).iter_mut().for_each(|a| { @@ -311,29 +304,42 @@ impl VecZnx { } } - pub fn add_normal(&mut self, source: &mut Source, sigma: f64, bound: f64, log_k: usize) { - self.add_dist_f64(source, Normal::new(0.0, sigma).unwrap(), bound, log_k); + pub fn add_normal( + &mut self, + log_base2k: usize, + source: &mut Source, + sigma: f64, + bound: f64, + log_k: usize, + ) { + self.add_dist_f64( + log_base2k, + source, + Normal::new(0.0, sigma).unwrap(), + bound, + log_k, + ); } - pub fn trunc_pow2(&mut self, k: usize) { + pub fn trunc_pow2(&mut self, log_base2k: usize, k: usize) { if k == 0 { return; } self.data - .truncate((self.limbs() - k / self.log_base2k) * self.n()); + .truncate((self.limbs() - k / log_base2k) * self.n()); - let k_rem: usize = k % self.log_base2k; + let k_rem: usize = k % log_base2k; if k_rem != 0 { - let mask: i64 = ((1 << (self.log_base2k - k_rem - 1)) - 1) << k_rem; + let mask: i64 = ((1 << (log_base2k - k_rem - 1)) - 1) << k_rem; self.at_mut(self.limbs() - 1) .iter_mut() .for_each(|x: &mut i64| *x &= mask) } } - pub fn rsh(&mut self, k: usize, carry: &mut [u8]) { + pub fn rsh(&mut self, log_base2k: usize, k: usize, carry: &mut [u8]) { assert!( carry.len() >> 3 >= self.n(), "invalid carry: carry.len()/8={} < self.n()={}", @@ -342,14 +348,14 @@ impl VecZnx { ); let limbs: usize = self.limbs(); - let limbs_steps: usize = k / self.log_base2k; + let limbs_steps: usize = k / log_base2k; self.data.rotate_right(self.n * limbs_steps); unsafe { znx_zero_i64_ref((self.n * limbs_steps) as u64, self.data.as_mut_ptr()); } - let k_rem = k % self.log_base2k; + let k_rem = k % log_base2k; if k_rem != 0 { let 
carry_i64: &mut [i64] = cast_mut_u8_to_mut_i64_slice(carry); @@ -359,7 +365,7 @@ impl VecZnx { } let mask: i64 = (1 << k_rem) - 1; - let log_base2k: usize = self.log_base2k; + let log_base2k: usize = log_base2k; (limbs_steps..limbs).for_each(|i| { izip!(carry_i64.iter_mut(), self.at_mut(i).iter_mut()).for_each(|(ci, xi)| { @@ -410,14 +416,14 @@ mod tests { let log_base2k: usize = 17; let limbs: usize = 5; let log_k: usize = limbs * log_base2k - 5; - let mut a: VecZnx = VecZnx::new(n, log_base2k, limbs); + let mut a: VecZnx = VecZnx::new(n, limbs); let mut have: Vec = vec![i64::default(); n]; have.iter_mut() .enumerate() .for_each(|(i, x)| *x = (i as i64) - (n as i64) / 2); - a.from_i64(&have, 10, log_k); + a.from_i64(log_base2k, &have, 10, log_k); let mut want = vec![i64::default(); n]; - a.to_i64(&mut want, log_k); + a.to_i64(log_base2k, &mut want, log_k); izip!(want, have).for_each(|(a, b)| assert_eq!(a, b)); } @@ -427,7 +433,7 @@ mod tests { let log_base2k: usize = 17; let limbs: usize = 5; let log_k: usize = limbs * log_base2k - 5; - let mut a: VecZnx = VecZnx::new(n, log_base2k, limbs); + let mut a: VecZnx = VecZnx::new(n, limbs); let mut have: Vec = vec![i64::default(); n]; let mut source = Source::new([1; 32]); have.iter_mut().for_each(|x| { @@ -435,11 +441,11 @@ mod tests { .next_u64n(u64::MAX, u64::MAX) .wrapping_sub(u64::MAX / 2 + 1) as i64; }); - a.from_i64(&have, 63, log_k); + a.from_i64(log_base2k, &have, 63, log_k); //(0..a.limbs()).for_each(|i| println!("i:{} -> {:?}", i, a.at(i))); let mut want = vec![i64::default(); n]; //(0..a.limbs()).for_each(|i| println!("i:{} -> {:?}", i, a.at(i))); - a.to_i64(&mut want, log_k); + a.to_i64(log_base2k, &mut want, log_k); izip!(want, have).for_each(|(a, b)| assert_eq!(a, b, "{} != {}", a, b)); } #[test] @@ -448,7 +454,7 @@ mod tests { let log_base2k: usize = 17; let limbs: usize = 5; let log_k: usize = limbs * log_base2k - 5; - let mut a: VecZnx = VecZnx::new(n, log_base2k, limbs); + let mut a: VecZnx = 
VecZnx::new(n, limbs); let mut have: Vec = vec![i64::default(); n]; let mut source = Source::new([1; 32]); have.iter_mut().for_each(|x| { @@ -456,16 +462,16 @@ mod tests { .next_u64n(u64::MAX, u64::MAX) .wrapping_sub(u64::MAX / 2 + 1) as i64; }); - a.from_i64(&have, 63, log_k); + a.from_i64(log_base2k, &have, 63, log_k); let mut carry: Vec = vec![u8::default(); n * 8]; - a.normalize(&mut carry); + a.normalize(log_base2k, &mut carry); let base_half = 1 << (log_base2k - 1); a.data .iter() .for_each(|x| assert!(x.abs() <= base_half, "|x|={} > 2^(k-1)={}", x, base_half)); let mut want = vec![i64::default(); n]; - a.to_i64(&mut want, log_k); + a.to_i64(log_base2k, &mut want, log_k); izip!(want, have).for_each(|(a, b)| assert_eq!(a, b, "{} != {}", a, b)); } } diff --git a/base2k/src/vec_znx_big_arithmetic.rs b/base2k/src/vec_znx_big_arithmetic.rs index 84d0396..32c6e74 100644 --- a/base2k/src/vec_znx_big_arithmetic.rs +++ b/base2k/src/vec_znx_big_arithmetic.rs @@ -4,7 +4,7 @@ use crate::ffi::vec_znx_big::{ vec_znx_bigcoeff_t, }; use crate::ffi::vec_znx_dft::vec_znx_dft_t; - +use crate::Free; use crate::{Module, VecZnx, VecZnxDft}; pub struct VecZnxBig(pub *mut vec_znx_bigcoeff_t, pub usize); @@ -16,7 +16,10 @@ impl VecZnxBig { pub fn limbs(&self) -> usize { self.1 } - pub fn delete(self) { +} + +impl Free for VecZnxBig { + fn free(self) { unsafe { delete_vec_znx_big(self.0); } @@ -139,7 +142,13 @@ impl Module { } // b <- normalize(a) - pub fn vec_znx_big_normalize(&self, b: &mut VecZnx, a: &VecZnxBig, tmp_bytes: &mut [u8]) { + pub fn vec_znx_big_normalize( + &self, + log_base2k: usize, + b: &mut VecZnx, + a: &VecZnxBig, + tmp_bytes: &mut [u8], + ) { let limbs: usize = b.limbs(); assert!( b.limbs() >= limbs, @@ -156,7 +165,7 @@ impl Module { unsafe { vec_znx_big_normalize_base2k( self.0, - b.log_base2k as u64, + log_base2k as u64, b.as_mut_ptr(), limbs as u64, b.n() as u64, diff --git a/base2k/src/vec_znx_dft.rs b/base2k/src/vec_znx_dft.rs index c5f3bec..151bb5e 100644 
--- a/base2k/src/vec_znx_dft.rs +++ b/base2k/src/vec_znx_dft.rs @@ -3,7 +3,7 @@ use crate::ffi::vec_znx_dft::{ delete_vec_znx_dft, new_vec_znx_dft, vec_znx_dft_t, vec_znx_idft, vec_znx_idft_tmp_a, vec_znx_idft_tmp_bytes, }; -use crate::{Module, VecZnxBig}; +use crate::{Free, Module, VecZnxBig}; pub struct VecZnxDft(pub *mut vec_znx_dft_t, pub usize); @@ -14,8 +14,10 @@ impl VecZnxDft { pub fn limbs(&self) -> usize { self.1 } +} - pub fn delete(self) { +impl Free for VecZnxDft { + fn free(self) { unsafe { delete_vec_znx_dft(self.0) }; drop(self); } diff --git a/base2k/src/vector_matrix_product.rs b/base2k/src/vector_matrix_product.rs deleted file mode 100644 index dc098bf..0000000 --- a/base2k/src/vector_matrix_product.rs +++ /dev/null @@ -1,366 +0,0 @@ -use crate::ffi::vmp::{ - delete_vmp_pmat, new_vmp_pmat, vmp_apply_dft, vmp_apply_dft_tmp_bytes, vmp_apply_dft_to_dft, - vmp_apply_dft_to_dft_tmp_bytes, vmp_pmat_t, vmp_prepare_contiguous, - vmp_prepare_contiguous_tmp_bytes, -}; -use crate::{Module, VecZnx, VecZnxDft}; -use std::cmp::min; - -/// Vector Matrix Product Prepared Matrix: a vector of [VecZnx], -/// stored as a 3D matrix in the DFT domain in a single contiguous array. -pub struct VmpPMat { - /// The pointer to the C memory. - pub data: *mut vmp_pmat_t, - /// The number of [VecZnx]. - pub rows: usize, - /// The number of limbs in each [VecZnx]. - pub cols: usize, - /// The ring degree of each [VecZnx]. - pub n: usize, -} - -impl VmpPMat { - - /// Returns the pointer to the [vmp_pmat_t]. - pub fn data(&self) -> *mut vmp_pmat_t { - self.data - } - - /// Returns the number of rows of the [VmpPMat]. - /// The number of rows (i.e. of [VecZnx]) of the [VmpPMat]. - pub fn rows(&self) -> usize { - self.rows - } - - /// Returns the number of cols of the [VmpPMat]. - /// The number of cols refers to the number of limbs - /// of the prepared [VecZnx]. - pub fn cols(&self) -> usize { - self.cols - } - - /// Returns the ring dimension of the [VmpPMat]. 
- pub fn n(&self) -> usize { - self.n - } - - /// Returns a copy of the backend array at index (i, j) of the [VmpPMat]. - /// When using FFT64 as backend, T should be f64. - /// When using NTT120 as backend, T should be i64. - pub fn at(&self, row: usize, col: usize) -> Vec { - let mut res: Vec = vec![T::default(); self.n]; - - if self.n < 8 { - res.copy_from_slice( - &self.get_backend_array::()[(row + col * self.rows()) * self.n() - ..(row + col * self.rows()) * (self.n() + 1)], - ); - } else { - (0..self.n >> 3).for_each(|blk| { - res[blk * 8..(blk + 1) * 8].copy_from_slice(&self.get_array(row, col, blk)[..8]); - }); - } - - res - } - - /// When using FFT64 as backend, T should be f64. - /// When using NTT120 as backend, T should be i64. - fn get_array(&self, row: usize, col: usize, blk: usize) -> &[T] { - let nrows: usize = self.rows(); - let ncols: usize = self.cols(); - if col == (ncols - 1) && (ncols & 1 == 1) { - &self.get_backend_array::()[blk * nrows * ncols * 8 + col * nrows * 8 + row * 8..] - } else { - &self.get_backend_array::()[blk * nrows * ncols * 8 - + (col / 2) * (2 * nrows) * 8 - + row * 2 * 8 - + (col % 2) * 8..] - } - } - - /// Returns a non-mutable reference of T to the entire contiguous array of the [VmpPMat]. - /// When using FFT64 as backend, T should be f64. - /// When using NTT120 as backend, T should be i64. - /// The length of the returned array is rows * cols * n. - pub fn get_backend_array(&self) -> &[T] { - let ptr: *const T = self.data as *const T; - let len: usize = (self.rows() * self.cols() * self.n() * 8) / std::mem::size_of::(); - unsafe { &std::slice::from_raw_parts(ptr, len) } - } - - /// frees the memory and self destructs. - pub fn delete(self) { - unsafe { delete_vmp_pmat(self.data) }; - drop(self); - } -} - -impl Module { - - /// Allocates a new [VmpPMat] with the given number of rows and columns. 
- pub fn new_vmp_pmat(&self, rows: usize, cols: usize) -> VmpPMat { - unsafe { - VmpPMat { - data: new_vmp_pmat(self.0, rows as u64, cols as u64), - rows, - cols, - n: self.n(), - } - } - } - - /// Returns the number of bytes needed as scratch space for [Self::vmp_prepare_contiguous]. - pub fn vmp_prepare_contiguous_tmp_bytes(&self, rows: usize, cols: usize) -> usize { - unsafe { vmp_prepare_contiguous_tmp_bytes(self.0, rows as u64, cols as u64) as usize } - } - - /// Prepares a [VmpPMat] given a contiguous array of [i64]. - /// The helper struct [Matrix3D] can be used to contruct the - /// appropriate contiguous array. - /// - /// # Example - /// ``` - /// let mut b_mat: Matrix3D = Matrix3D::new(rows, cols, n); - /// - /// (0..min(rows, cols)).for_each(|i| { - /// b_mat.at_mut(i, i)[1] = 1 as i64; - /// }); - /// let mut vmp_pmat: VmpPMat = module.new_vmp_pmat(rows, cols); - /// module.vmp_prepare_contiguous(&mut vmp_pmat, &b_mat.data, &mut buf); - /// ``` - pub fn vmp_prepare_contiguous(&self, b: &mut VmpPMat, a: &[i64], buf: &mut [u8]) { - unsafe { - vmp_prepare_contiguous( - self.0, - b.data(), - a.as_ptr(), - b.rows() as u64, - b.cols() as u64, - buf.as_mut_ptr(), - ); - } - } - - pub fn vmp_prepare_dblptr(&self, b: &mut VmpPMat, a: &Vec, buf: &mut [u8]) { - let rows: usize = b.rows(); - let cols: usize = b.cols(); - - let mut mat: Matrix3D = Matrix3D::::new(rows, cols, self.n()); - - (0..min(rows, a.len())).for_each(|i| { - mat.set_row(i, &a[i].data); - }); - - self.vmp_prepare_contiguous(b, &mat.data, buf); - - /* - NOT IMPLEMENTED IN SPQLIOS - let mut ptrs: Vec<*const i64> = a.iter().map(|v| v.data.as_ptr()).collect(); - unsafe { - vmp_prepare_dblptr( - self.0, - b.data(), - ptrs.as_mut_ptr(), - b.rows() as u64, - b.cols() as u64, - buf.as_mut_ptr(), - ); - } - */ - } - - pub fn vmp_apply_dft_tmp_bytes( - &self, - c_limbs: usize, - a_limbs: usize, - rows: usize, - cols: usize, - ) -> usize { - unsafe { - vmp_apply_dft_tmp_bytes( - self.0, - c_limbs as u64, 
- a_limbs as u64, - rows as u64, - cols as u64, - ) as usize - } - } - - pub fn vmp_apply_dft(&self, c: &mut VecZnxDft, a: &VecZnx, b: &VmpPMat, buf: &mut [u8]) { - unsafe { - vmp_apply_dft( - self.0, - c.0, - c.limbs() as u64, - a.as_ptr(), - a.limbs() as u64, - a.n() as u64, - b.data(), - b.rows() as u64, - b.cols() as u64, - buf.as_mut_ptr(), - ) - } - } - - pub fn vmp_apply_dft_to_dft_tmp_bytes( - &self, - c_limbs: usize, - a_limbs: usize, - rows: usize, - cols: usize, - ) -> usize { - unsafe { - vmp_apply_dft_to_dft_tmp_bytes( - self.0, - c_limbs as u64, - a_limbs as u64, - rows as u64, - cols as u64, - ) as usize - } - } - - pub fn vmp_apply_dft_to_dft( - &self, - c: &mut VecZnxDft, - a: &VecZnxDft, - b: &VmpPMat, - buf: &mut [u8], - ) { - unsafe { - vmp_apply_dft_to_dft( - self.0, - c.0, - c.limbs() as u64, - a.0, - a.limbs() as u64, - b.data(), - b.rows() as u64, - b.cols() as u64, - buf.as_mut_ptr(), - ) - } - } - - pub fn vmp_apply_dft_to_dft_inplace(&self, b: &mut VecZnxDft, a: &VmpPMat, buf: &mut [u8]) { - unsafe { - vmp_apply_dft_to_dft( - self.0, - b.0, - b.limbs() as u64, - b.0, - b.limbs() as u64, - a.data(), - a.rows() as u64, - a.cols() as u64, - buf.as_mut_ptr(), - ) - } - } -} - -/// A helper struture that stores a 3D matrix as a contiguous array. -/// To be passed to [Module::vmp_prepare_contiguous]. -/// -/// rows: index of the i-th base2K power. -/// cols: index of the j-th limb of the i-th row. -/// n : polynomial degree. -/// -/// A [Matrix3D] can be seen as a vector of [VecZnx]. -pub struct Matrix3D { - pub data: Vec, - pub rows: usize, - pub cols: usize, - pub n: usize, -} - -impl Matrix3D { - /// Allocates a new [Matrix3D] with the respective dimensions. 
- /// - /// # Example - /// ``` - /// let rows = 5; // #decomp - /// let cols = 5; // #limbs - /// let n = 1024; // #coeffs - /// - /// let mut mat = Matrix3D::::new(rows, cols, n); - /// ``` - pub fn new(rows: usize, cols: usize, n: usize) -> Self { - let size = rows * cols * n; - Self { - data: vec![T::default(); size], - rows, - cols, - n, - } - } - - /// Returns a non-mutable reference to the entry (row, col) of the [Matrix3D]. - /// The returned array is of size n. - /// - /// # Example - /// ``` - /// let rows = 5; // #decomp - /// let cols = 5; // #limbs - /// let n = 1024; // #coeffs - /// - /// let mut mat = Matrix3D::::new(rows, cols, n); - /// - /// let elem: &[i64] = mat.at(5, 5); // size n - /// ``` - pub fn at(&self, row: usize, col: usize) -> &[T] { - assert!(row <= self.rows && col <= self.cols); - let idx: usize = row * (self.n * self.cols) + col * self.n; - &self.data[idx..idx + self.n] - } - - /// Returns a mutable reference of the array at the (row, col) entry of the [Matrix3D]. - /// The returned array is of size n. - /// - /// # Example - /// ``` - /// let rows = 5; // #decomp - /// let cols = 5; // #limbs - /// let n = 1024; // #coeffs - /// - /// let mut mat = Matrix3D::::new(rows, cols, n); - /// - /// let elem: &mut [i64] = mat.at_mut(5, 5); // size n - /// ``` - pub fn at_mut(&mut self, row: usize, col: usize) -> &mut [T] { - assert!(row <= self.rows && col <= self.cols); - let idx: usize = row * (self.n * self.cols) + col * self.n; - &mut self.data[idx..idx + self.n] - } - - /// Sets the entry \[row\] of the [Matrix3D]. - /// Typicall this is used to assign a [VecZnx] to the i-th row - /// of the [Matrix3D]. 
- /// - /// # Example - /// ``` - /// let rows = 5; // #decomp - /// let cols = 5; // #limbs - /// let n = 1024; // #coeffs - /// - /// let mut mat = Matrix3D::::new(rows, cols, n); - /// - /// let a: Vec = VecZnx::new(n, cols); - /// - /// mat.set_row(1, &a.data); - /// ``` - pub fn set_row(&mut self, row: usize, a: &[T]) { - assert!( - row < self.rows, - "invalid argument row: row={} > self.rows={}", - row, - self.rows - ); - let idx: usize = row * (self.n * self.cols); - let size: usize = min(a.len(), self.cols * self.n); - self.data[idx..idx + size].copy_from_slice(&a[..size]); - } -} diff --git a/base2k/src/vmp.rs b/base2k/src/vmp.rs new file mode 100644 index 0000000..1260c0d --- /dev/null +++ b/base2k/src/vmp.rs @@ -0,0 +1,592 @@ +use crate::ffi::vmp::{ + delete_vmp_pmat, new_vmp_pmat, vmp_apply_dft, vmp_apply_dft_tmp_bytes, vmp_apply_dft_to_dft, + vmp_apply_dft_to_dft_tmp_bytes, vmp_pmat_t, vmp_prepare_contiguous, + vmp_prepare_contiguous_tmp_bytes, +}; +use crate::Free; +use crate::{Module, VecZnx, VecZnxDft}; +use std::cmp::min; + +/// Vector Matrix Product Prepared Matrix: a vector of [VecZnx], +/// stored as a 3D matrix in the DFT domain in a single contiguous array. +/// Each row of the [VmpPMat] can be seen as a [VecZnxDft]. +/// +/// The backend array of [VmpPMat] is allocate in C, +/// and thus must be manually freed. +/// +/// [VmpPMat] is used to permform a vector matrix product between a [VecZnx] and a [VmpPMat]. +/// See the trait [VectorMatrixProduct] for additional information. +pub struct VmpPMat { + /// The pointer to the C memory. + pub data: *mut vmp_pmat_t, + /// The number of [VecZnxDft]. + pub rows: usize, + /// The number of limbs in each [VecZnxDft]. + pub cols: usize, + /// The ring degree of each [VecZnxDft]. + pub n: usize, +} + +impl VmpPMat { + /// Returns the pointer to the [vmp_pmat_t]. + pub fn data(&self) -> *mut vmp_pmat_t { + self.data + } + + /// Returns the number of rows of the [VmpPMat]. + /// The number of rows (i.e. 
of [VecZnx]) of the [VmpPMat].
+    pub fn rows(&self) -> usize {
+        self.rows
+    }
+
+    /// Returns the number of cols of the [VmpPMat].
+    /// The number of cols refers to the number of limbs
+    /// of the prepared [VecZnx].
+    pub fn cols(&self) -> usize {
+        self.cols
+    }
+
+    /// Returns the ring dimension of the [VmpPMat].
+    pub fn n(&self) -> usize {
+        self.n
+    }
+
+    /// Returns a copy of the backend array at index (row, col) of the [VmpPMat].
+    /// When using [`crate::FFT64`] as backend, `T` should be [f64].
+    /// When using [`crate::NTT120`] as backend, `T` should be [i64].
+    pub fn at<T: Default + Copy>(&self, row: usize, col: usize) -> Vec<T> {
+        let mut res: Vec<T> = vec![T::default(); self.n];
+
+        if self.n < 8 {
+            res.copy_from_slice(
+                &self.get_backend_array::<T>()[(row + col * self.rows()) * self.n()
+                    ..(row + col * self.rows() + 1) * self.n()],
+            );
+        } else {
+            (0..self.n >> 3).for_each(|blk| {
+                res[blk * 8..(blk + 1) * 8].copy_from_slice(&self.get_array(row, col, blk)[..8]);
+            });
+        }
+
+        res
+    }
+
+    /// When using [`crate::FFT64`] as backend, `T` should be [f64].
+    /// When using [`crate::NTT120`] as backend, `T` should be [i64].
+    fn get_array<T>(&self, row: usize, col: usize, blk: usize) -> &[T] {
+        let nrows: usize = self.rows();
+        let ncols: usize = self.cols();
+        if col == (ncols - 1) && (ncols & 1 == 1) {
+            &self.get_backend_array::<T>()[blk * nrows * ncols * 8 + col * nrows * 8 + row * 8..]
+        } else {
+            &self.get_backend_array::<T>()[blk * nrows * ncols * 8
+                + (col / 2) * (2 * nrows) * 8
+                + row * 2 * 8
+                + (col % 2) * 8..]
+        }
+    }
+
+    /// Returns a non-mutable reference of [T] of the entire contiguous array of the [VmpPMat].
+    /// When using [`crate::FFT64`] as backend, `T` should be [f64].
+    /// When using [`crate::NTT120`] as backend, `T` should be [i64].
+    /// The length of the returned array is rows * cols * n.
+ pub fn get_backend_array(&self) -> &[T] { + let ptr: *const T = self.data as *const T; + let len: usize = (self.rows() * self.cols() * self.n() * 8) / std::mem::size_of::(); + unsafe { &std::slice::from_raw_parts(ptr, len) } + } +} + +impl Free for VmpPMat { + fn free(self) { + unsafe { delete_vmp_pmat(self.data) }; + drop(self); + } +} + +/// This trait implements methods for vector matrix product, +/// that is, multiplying a [VecZnx] with a [VmpPMat]. +pub trait VectorMatrixProduct { + /// Allocates a new [VmpPMat] with the given number of rows and columns. + fn new_vmp_pmat(&self, rows: usize, cols: usize) -> VmpPMat; + + /// Returns the number of bytes needed as scratch space for [VectorMatrixProduct::vmp_prepare_contiguous]. + fn vmp_prepare_contiguous_tmp_bytes(&self, rows: usize, cols: usize) -> usize; + + /// Prepares a [VmpPMat] from a contiguous array of [i64]. + /// The helper struct [Matrix3D] can be used to contruct and populate + /// the appropriate contiguous array. + /// + /// The size of buf can be obtained with [VectorMatrixProduct::vmp_prepare_contiguous_tmp_bytes]. + /// + /// # Example + /// ``` + /// use base2k::{Module, Matrix3D, VmpPMat, FFT64, Free}; + /// use base2k::vmp::VectorMatrixProduct; + /// use std::cmp::min; + /// + /// let n: usize = 1024; + /// let module = Module::new::(n); + /// let rows = 5; + /// let cols = 6; + /// + /// let mut b_mat: Matrix3D = Matrix3D::new(rows, cols, n); + /// + /// // Populates the i-th row of b_math with X^1 * 2^(i * log_w) (here log_w is undefined) + /// (0..min(rows, cols)).for_each(|i| { + /// b_mat.at_mut(i, i)[1] = 1 as i64; + /// }); + /// + /// let mut buf: Vec = vec![u8::default(); module.vmp_prepare_contiguous_tmp_bytes(rows, cols)]; + /// + /// let mut vmp_pmat: VmpPMat = module.new_vmp_pmat(rows, cols); + /// module.vmp_prepare_contiguous(&mut vmp_pmat, &b_mat.data, &mut buf); + /// + /// vmp_pmat.free() // don't forget to free the memory once vmp_pmat is not needed anymore. 
+ /// ``` + fn vmp_prepare_contiguous(&self, b: &mut VmpPMat, a: &[i64], buf: &mut [u8]); + + /// Prepares a [VmpPMat] from a vector of [VecZnx]. + /// + /// The size of buf can be obtained with [VectorMatrixProduct::vmp_prepare_contiguous_tmp_bytes]. + /// + /// # Example + /// ``` + /// use base2k::{Module, FFT64, Matrix3D, VmpPMat, VecZnx, Free}; + /// use base2k::vmp::VectorMatrixProduct; + /// use std::cmp::min; + /// + /// let n: usize = 1024; + /// let module: Module = Module::new::(n); + /// let rows: usize = 5; + /// let cols: usize = 6; + /// + /// let mut vecznx: Vec= Vec::new(); + /// (0..rows).for_each(|_|{ + /// vecznx.push(module.new_vec_znx(cols)); + /// }); + /// + /// let mut buf: Vec = vec![u8::default(); module.vmp_prepare_contiguous_tmp_bytes(rows, cols)]; + /// + /// let mut vmp_pmat: VmpPMat = module.new_vmp_pmat(rows, cols); + /// module.vmp_prepare_dblptr(&mut vmp_pmat, &vecznx, &mut buf); + /// + /// vmp_pmat.free(); + /// module.free(); + /// ``` + fn vmp_prepare_dblptr(&self, b: &mut VmpPMat, a: &Vec, buf: &mut [u8]); + + /// Returns the size of the stratch space necessary for [VectorMatrixProduct::vmp_apply_dft]. + fn vmp_apply_dft_tmp_bytes( + &self, + c_limbs: usize, + a_limbs: usize, + rows: usize, + cols: usize, + ) -> usize; + + /// Applies the vector matrix product [VecZnxDft] x [VmpPMat]. + /// The size of `buf` is given by [VectorMatrixProduct::vmp_apply_dft_to_dft_tmp_bytes]. + /// + /// A vector matrix product is equivalent to a sum of [ScalarVectorProduct::svp_apply_dft] + /// where each [crate::Scalar] is a limb of the input [VecZnxDft] (equivalent to an [crate::SvpPPol]) + /// and each vector a [VecZnxDft] (row) of the [VmpPMat]. + /// + /// As such, given an input [VecZnx] of `i` limbs and a [VmpPMat] of `i` rows and + /// `j` cols, the output is a [VecZnx] of `j` limbs. + /// + /// If there is a mismatch between the dimensions the largest valid ones are used. 
+ /// + /// ```text + /// |a b c d| x |e f g| = (a * |e f g| + b * |h i j| + c * |k l m|) = |n o p| + /// |h i j| + /// |k l m| + /// ``` + /// where each element is a [VecZnxDft]. + /// + /// # Example + /// ``` + /// use base2k::{Module, VecZnx, VecZnxDft, VmpPMat, FFT64, Free}; + /// use base2k::vmp::VectorMatrixProduct; + /// + /// let n = 1024; + /// + /// let module: Module = Module::new::(n); + /// let limbs: usize = 5; + /// + /// let rows: usize = limbs; + /// let cols: usize = limbs + 1; + /// let c_limbs: usize = cols; + /// let a_limbs: usize = limbs; + /// let tmp_bytes: usize = module.vmp_apply_dft_tmp_bytes(c_limbs, a_limbs, rows, cols); + /// + /// let mut buf: Vec = vec![0; tmp_bytes]; + /// let mut vmp_pmat: VmpPMat = module.new_vmp_pmat(rows, cols); + /// + /// let a: VecZnx = module.new_vec_znx(limbs); + /// let mut c_dft: VecZnxDft = module.new_vec_znx_dft(cols); + /// module.vmp_apply_dft(&mut c_dft, &a, &vmp_pmat, &mut buf); + /// + /// c_dft.free(); + /// vmp_pmat.free(); + /// module.free(); + /// ``` + fn vmp_apply_dft(&self, c: &mut VecZnxDft, a: &VecZnx, b: &VmpPMat, buf: &mut [u8]); + + /// Returns the size of the stratch space necessary for [VectorMatrixProduct::vmp_apply_dft_to_dft]. + fn vmp_apply_dft_to_dft_tmp_bytes( + &self, + c_limbs: usize, + a_limbs: usize, + rows: usize, + cols: usize, + ) -> usize; + + /// Applies the vector matrix product [VecZnxDft] x [VmpPMat]. + /// The size of `buf` is given by [VectorMatrixProduct::vmp_apply_dft_to_dft_tmp_bytes]. + /// + /// A vector matrix product is equivalent to a sum of [ScalarVectorProduct::svp_apply_dft] + /// where each [crate::Scalar] is a limb of the input [VecZnxDft] (equivalent to an [crate::SvpPPol]) + /// and each vector a [VecZnxDft] (row) of the [VmpPMat]. + /// + /// As such, given an input [VecZnx] of `i` limbs and a [VmpPMat] of `i` rows and + /// `j` cols, the output is a [VecZnx] of `j` limbs. 
+ /// + /// If there is a mismatch between the dimensions the largest valid ones are used. + /// + /// ```text + /// |a b c d| x |e f g| = (a * |e f g| + b * |h i j| + c * |k l m|) = |n o p| + /// |h i j| + /// |k l m| + /// ``` + /// where each element is a [VecZnxDft]. + /// + /// # Example + /// ``` + /// use base2k::{Module, VecZnx, VecZnxDft, VmpPMat, FFT64, Free}; + /// use base2k::vmp::VectorMatrixProduct; + /// + /// let n = 1024; + /// + /// let module: Module = Module::new::(n); + /// let limbs: usize = 5; + /// + /// let rows: usize = limbs; + /// let cols: usize = limbs + 1; + /// let c_limbs: usize = cols; + /// let a_limbs: usize = limbs; + /// let tmp_bytes: usize = module.vmp_apply_dft_tmp_bytes(c_limbs, a_limbs, rows, cols); + /// + /// let mut buf: Vec = vec![0; tmp_bytes]; + /// let mut vmp_pmat: VmpPMat = module.new_vmp_pmat(rows, cols); + /// + /// let a_dft: VecZnxDft = module.new_vec_znx_dft(limbs); + /// let mut c_dft: VecZnxDft = module.new_vec_znx_dft(cols); + /// module.vmp_apply_dft_to_dft(&mut c_dft, &a_dft, &vmp_pmat, &mut buf); + /// + /// a_dft.free(); + /// c_dft.free(); + /// vmp_pmat.free(); + /// module.free(); + /// ``` + fn vmp_apply_dft_to_dft(&self, c: &mut VecZnxDft, a: &VecZnxDft, b: &VmpPMat, buf: &mut [u8]); + + /// Applies the vector matrix product [VecZnxDft] x [VmpPMat] in place. + /// The size of `buf` is given by [VectorMatrixProduct::vmp_apply_dft_to_dft_tmp_bytes]. + /// + /// A vector matrix product is equivalent to a sum of [ScalarVectorProduct::svp_apply_dft] + /// where each [crate::Scalar] is a limb of the input [VecZnxDft] (equivalent to an [crate::SvpPPol]) + /// and each vector a [VecZnxDft] (row) of the [VmpPMat]. + /// + /// As such, given an input [VecZnx] of `i` limbs and a [VmpPMat] of `i` rows and + /// `j` cols, the output is a [VecZnx] of `j` limbs. + /// + /// If there is a mismatch between the dimensions the largest valid ones are used. 
+ /// + /// ```text + /// |a b c d| x |e f g| = (a * |e f g| + b * |h i j| + c * |k l m|) = |n o p| + /// |h i j| + /// |k l m| + /// ``` + /// where each element is a [VecZnxDft]. + /// + /// # Example + /// ``` + /// use base2k::{Module, VecZnx, VecZnxDft, VmpPMat, FFT64, Free}; + /// use base2k::vmp::VectorMatrixProduct; + /// + /// let n = 1024; + /// + /// let module: Module = Module::new::(n); + /// let limbs: usize = 5; + /// + /// let rows: usize = limbs; + /// let cols: usize = limbs + 1; + /// let tmp_bytes: usize = module.vmp_apply_dft_tmp_bytes(limbs, limbs, rows, cols); + /// + /// let mut buf: Vec = vec![0; tmp_bytes]; + /// let a: VecZnx = module.new_vec_znx(limbs); + /// let mut vmp_pmat: VmpPMat = module.new_vmp_pmat(rows, cols); + /// + /// let mut c_dft: VecZnxDft = module.new_vec_znx_dft(limbs); + /// module.vmp_apply_dft_to_dft_inplace(&mut c_dft, &vmp_pmat, &mut buf); + /// + /// c_dft.free(); + /// vmp_pmat.free(); + /// module.free(); + /// ``` + fn vmp_apply_dft_to_dft_inplace(&self, b: &mut VecZnxDft, a: &VmpPMat, buf: &mut [u8]); +} + +impl VectorMatrixProduct for Module { + fn new_vmp_pmat(&self, rows: usize, cols: usize) -> VmpPMat { + unsafe { + VmpPMat { + data: new_vmp_pmat(self.0, rows as u64, cols as u64), + rows, + cols, + n: self.n(), + } + } + } + fn vmp_prepare_contiguous_tmp_bytes(&self, rows: usize, cols: usize) -> usize { + unsafe { vmp_prepare_contiguous_tmp_bytes(self.0, rows as u64, cols as u64) as usize } + } + + fn vmp_prepare_contiguous(&self, b: &mut VmpPMat, a: &[i64], buf: &mut [u8]) { + unsafe { + vmp_prepare_contiguous( + self.0, + b.data(), + a.as_ptr(), + b.rows() as u64, + b.cols() as u64, + buf.as_mut_ptr(), + ); + } + } + + fn vmp_prepare_dblptr(&self, b: &mut VmpPMat, a: &Vec, buf: &mut [u8]) { + let rows: usize = b.rows(); + let cols: usize = b.cols(); + + let mut mat: Matrix3D = Matrix3D::::new(rows, cols, self.n()); + + (0..min(rows, a.len())).for_each(|i| { + mat.set_row(i, &a[i].data); + }); + + 
self.vmp_prepare_contiguous(b, &mat.data, buf); + + /* + NOT IMPLEMENTED IN SPQLIOS + let mut ptrs: Vec<*const i64> = a.iter().map(|v| v.data.as_ptr()).collect(); + unsafe { + vmp_prepare_dblptr( + self.0, + b.data(), + ptrs.as_mut_ptr(), + b.rows() as u64, + b.cols() as u64, + buf.as_mut_ptr(), + ); + } + */ + } + + fn vmp_apply_dft_tmp_bytes( + &self, + c_limbs: usize, + a_limbs: usize, + rows: usize, + cols: usize, + ) -> usize { + unsafe { + vmp_apply_dft_tmp_bytes( + self.0, + c_limbs as u64, + a_limbs as u64, + rows as u64, + cols as u64, + ) as usize + } + } + + fn vmp_apply_dft(&self, c: &mut VecZnxDft, a: &VecZnx, b: &VmpPMat, buf: &mut [u8]) { + unsafe { + vmp_apply_dft( + self.0, + c.0, + c.limbs() as u64, + a.as_ptr(), + a.limbs() as u64, + a.n() as u64, + b.data(), + b.rows() as u64, + b.cols() as u64, + buf.as_mut_ptr(), + ) + } + } + + fn vmp_apply_dft_to_dft_tmp_bytes( + &self, + c_limbs: usize, + a_limbs: usize, + rows: usize, + cols: usize, + ) -> usize { + unsafe { + vmp_apply_dft_to_dft_tmp_bytes( + self.0, + c_limbs as u64, + a_limbs as u64, + rows as u64, + cols as u64, + ) as usize + } + } + + fn vmp_apply_dft_to_dft(&self, c: &mut VecZnxDft, a: &VecZnxDft, b: &VmpPMat, buf: &mut [u8]) { + unsafe { + vmp_apply_dft_to_dft( + self.0, + c.0, + c.limbs() as u64, + a.0, + a.limbs() as u64, + b.data(), + b.rows() as u64, + b.cols() as u64, + buf.as_mut_ptr(), + ) + } + } + + fn vmp_apply_dft_to_dft_inplace(&self, b: &mut VecZnxDft, a: &VmpPMat, buf: &mut [u8]) { + unsafe { + vmp_apply_dft_to_dft( + self.0, + b.0, + b.limbs() as u64, + b.0, + b.limbs() as u64, + a.data(), + a.rows() as u64, + a.cols() as u64, + buf.as_mut_ptr(), + ) + } + } +} + +/// A helper struture that stores a 3D matrix as a contiguous array. +/// To be passed to [VectorMatrixProduct::vmp_prepare_contiguous]. +/// +/// rows: index of the i-th base2K power. +/// cols: index of the j-th limb of the i-th row. +/// n : polynomial degree. 
+/// +/// A [Matrix3D] can be seen as a vector of [VecZnx]. +pub struct Matrix3D { + pub data: Vec, + pub rows: usize, + pub cols: usize, + pub n: usize, +} + +impl Matrix3D { + /// Allocates a new [Matrix3D] with the respective dimensions. + /// + /// # Example + /// ``` + /// use base2k::Matrix3D; + /// + /// let rows = 5; // #decomp + /// let cols = 5; // #limbs + /// let n = 1024; // #coeffs + /// + /// let mut mat = Matrix3D::::new(rows, cols, n); + /// ``` + pub fn new(rows: usize, cols: usize, n: usize) -> Self { + let size = rows * cols * n; + Self { + data: vec![T::default(); size], + rows, + cols, + n, + } + } + + /// Returns a non-mutable reference to the entry (row, col) of the [Matrix3D]. + /// The returned array is of size n. + /// + /// # Example + /// ``` + /// use base2k::Matrix3D; + /// + /// let rows = 5; // #decomp + /// let cols = 5; // #limbs + /// let n = 1024; // #coeffs + /// + /// let mut mat = Matrix3D::::new(rows, cols, n); + /// + /// let elem: &[i64] = mat.at(4, 4); // size n + /// ``` + pub fn at(&self, row: usize, col: usize) -> &[T] { + assert!(row < self.rows && col < self.cols); + let idx: usize = row * (self.n * self.cols) + col * self.n; + &self.data[idx..idx + self.n] + } + + /// Returns a mutable reference of the array at the (row, col) entry of the [Matrix3D]. + /// The returned array is of size n. + /// + /// # Example + /// ``` + /// use base2k::Matrix3D; + /// + /// let rows = 5; // #decomp + /// let cols = 5; // #limbs + /// let n = 1024; // #coeffs + /// + /// let mut mat = Matrix3D::::new(rows, cols, n); + /// + /// let elem: &mut [i64] = mat.at_mut(4, 4); // size n + /// ``` + pub fn at_mut(&mut self, row: usize, col: usize) -> &mut [T] { + assert!(row < self.rows && col < self.cols); + let idx: usize = row * (self.n * self.cols) + col * self.n; + &mut self.data[idx..idx + self.n] + } + + /// Sets the entry \[row\] of the [Matrix3D]. 
+    /// Typically this is used to assign a [VecZnx] to the i-th row
+    /// of the [Matrix3D].
+    ///
+    /// # Example
+    /// ```
+    /// use base2k::{Matrix3D, VecZnx};
+    ///
+    /// let rows = 5; // #decomp
+    /// let cols = 5; // #limbs
+    /// let n = 1024; // #coeffs
+    ///
+    /// let mut mat = Matrix3D::<i64>::new(rows, cols, n);
+    ///
+    /// let a: VecZnx = VecZnx::new(n, cols);
+    ///
+    /// mat.set_row(1, &a.data);
+    /// ```
+    pub fn set_row(&mut self, row: usize, a: &[T]) {
+        assert!(
+            row < self.rows,
+            "invalid argument row: row={} >= self.rows={}",
+            row,
+            self.rows
+        );
+        let idx: usize = row * (self.n * self.cols);
+        let size: usize = min(a.len(), self.cols * self.n);
+        self.data[idx..idx + size].copy_from_slice(&a[..size]);
+    }
+}