From a790ff37ccf30eadd099f2436598ce1c17d4050c Mon Sep 17 00:00:00 2001 From: Jean-Philippe Bossuat Date: Tue, 4 Feb 2025 17:13:46 +0100 Subject: [PATCH] more doc --- base2k/examples/rlwe_encrypt.rs | 13 +- base2k/examples/vector_matrix_product.rs | 10 +- base2k/src/encoding.rs | 236 ++++++ base2k/src/free.rs | 43 ++ base2k/src/infos.rs | 77 ++ base2k/src/lib.rs | 35 +- base2k/src/sampling.rs | 94 +++ base2k/src/scalar.rs | 55 -- base2k/src/svp.rs | 88 ++- base2k/src/vec_znx.rs | 692 ++++++++++-------- base2k/src/vec_znx_arithmetic.rs | 168 ----- ...c_znx_big_arithmetic.rs => vec_znx_big.rs} | 56 +- base2k/src/vec_znx_dft.rs | 28 +- base2k/src/vmp.rs | 185 +++-- 14 files changed, 1097 insertions(+), 683 deletions(-) create mode 100644 base2k/src/encoding.rs create mode 100644 base2k/src/free.rs create mode 100644 base2k/src/infos.rs create mode 100644 base2k/src/sampling.rs delete mode 100644 base2k/src/scalar.rs delete mode 100644 base2k/src/vec_znx_arithmetic.rs rename base2k/src/{vec_znx_big_arithmetic.rs => vec_znx_big.rs} (78%) diff --git a/base2k/examples/rlwe_encrypt.rs b/base2k/examples/rlwe_encrypt.rs index 97f9f81..729d895 100644 --- a/base2k/examples/rlwe_encrypt.rs +++ b/base2k/examples/rlwe_encrypt.rs @@ -1,4 +1,7 @@ -use base2k::{Module, Scalar, SvpPPol, VecZnx, VecZnxBig, VecZnxDft, FFT64}; +use base2k::{ + Encoding, Infos, Module, Sampling, Scalar, SvpPPol, SvpPPolOps, VecZnx, VecZnxBig, VecZnxDft, + VecZnxOps, FFT64, +}; use itertools::izip; use sampling::source::Source; @@ -29,7 +32,7 @@ fn main() { // a <- Z_{2^prec}[X]/(X^{N}+1) let mut a: VecZnx = module.new_vec_znx(limbs); - a.fill_uniform(log_base2k, &mut source, limbs); + a.fill_uniform(log_base2k, limbs, &mut source); // Scratch space for DFT values let mut buf_dft: VecZnxDft = module.new_vec_znx_dft(a.limbs()); @@ -50,7 +53,7 @@ fn main() { .for_each(|x| *x = source.next_u64n(16, 15) as i64); // m - m.from_i64(log_base2k, &want, 4, log_scale); + m.encode_i64_vec(log_base2k, log_scale, &want, 4); m.normalize(log_base2k, &mut carry); // buf_big <- m - buf_big @@ -59,7 +62,7 @@ fn main() { // b <- normalize(buf_big) + e let mut b: VecZnx = module.new_vec_znx(limbs); module.vec_znx_big_normalize(log_base2k, &mut b, &buf_big, &mut carry); - b.add_normal(log_base2k, &mut source, 3.2, 19.0, log_base2k * limbs); + b.add_normal(log_base2k, log_base2k * limbs, &mut source, 3.2, 19.0); //Decrypt @@ -75,7 +78,7 @@ fn main() { // have = m * 2^{log_scale} + e let mut have: Vec = vec![i64::default(); n]; - res.to_i64(log_base2k, &mut have, res.limbs() * log_base2k); + res.decode_i64_vec(log_base2k, res.limbs() * log_base2k, &mut have); let scale: f64 = (1 << (res.limbs() * log_base2k - log_scale)) as f64; izip!(want.iter(), have.iter()) diff --git a/base2k/examples/vector_matrix_product.rs b/base2k/examples/vector_matrix_product.rs index 89f64d3..1bfb1b6 100644 --- a/base2k/examples/vector_matrix_product.rs +++ b/base2k/examples/vector_matrix_product.rs @@ -1,5 +1,7 @@ -use base2k::vmp::VectorMatrixProduct; -use base2k::{Free, Matrix3D, Module, VecZnx, VecZnxBig, VecZnxDft, VmpPMat, FFT64}; +use base2k::{ + Encoding, Free, Infos, Matrix3D, Module, VecZnx, VecZnxBig, VecZnxDft, VecZnxOps, VmpPMat, + VmpPMatOps, FFT64, +}; use std::cmp::min; fn main() { @@ -24,7 +26,7 @@ fn main() { a_values[1] = (1 << log_base2k) + 1; let mut a: VecZnx = module.new_vec_znx(limbs); - a.from_i64(log_base2k, &a_values, 32, log_k); + a.encode_i64_vec(log_base2k, log_k, &a_values, 32); a.normalize(log_base2k, &mut buf); (0..a.limbs()).for_each(|i| 
println!("{}: {:?}", i, a.at(i))); @@ -48,7 +50,7 @@ fn main() { module.vec_znx_big_normalize(log_base2k, &mut res, &c_big, &mut buf); let mut values_res: Vec = vec![i64::default(); n]; - res.to_i64(log_base2k, &mut values_res, log_k); + res.decode_i64_vec(log_base2k, log_k, &mut values_res); (0..res.limbs()).for_each(|i| println!("{}: {:?}", i, res.at(i))); diff --git a/base2k/src/encoding.rs b/base2k/src/encoding.rs new file mode 100644 index 0000000..6ca3ac3 --- /dev/null +++ b/base2k/src/encoding.rs @@ -0,0 +1,236 @@ +use crate::ffi::znx::znx_zero_i64_ref; +use crate::{Infos, VecZnx}; +use itertools::izip; +use std::cmp::min; + +pub trait Encoding { + /// encode a vector of i64 on the receiver. + /// + /// # Arguments + /// + /// * `log_base2k`: base two logarithm decomposition of the receiver. + /// * `log_k`: base two logarithm of the scaling of the data. + /// * `data`: data to encode on the receiver. + /// * `log_max`: base two logarithm of the infinity norm of the input data. + fn encode_i64_vec(&mut self, log_base2k: usize, log_k: usize, data: &[i64], log_max: usize); + + /// decode a vector of i64 from the receiver. + /// + /// # Arguments + /// + /// * `log_base2k`: base two logarithm decomposition of the receiver. + /// * `log_k`: base two logarithm of the scaling of the data. + /// * `data`: data to decode from the receiver. + fn decode_i64_vec(&self, log_base2k: usize, log_k: usize, data: &mut [i64]); + + /// encodes a single i64 on the receiver at the given index. + /// + /// # Arguments + /// + /// * `log_base2k`: base two logarithm decomposition of the receiver. + /// * `log_k`: base two logarithm of the scaling of the data. + /// * `i`: index of the coefficient on which to encode the data. + /// * `data`: data to encode on the receiver. + /// * `log_max`: base two logarithm of the infinity norm of the input data. + fn encode_i64_coeff( + &mut self, + log_base2k: usize, + log_k: usize, + i: usize, + data: i64, + log_max: usize, + ); + + /// decode a single of i64 from the receiver at the given index. + /// + /// # Arguments + /// + /// * `log_base2k`: base two logarithm decomposition of the receiver. + /// * `log_k`: base two logarithm of the scaling of the data. + /// * `i`: index of the coefficient to decode. + /// * `data`: data to decode from the receiver. + fn decode_i64_coeff(&self, log_base2k: usize, log_k: usize, i: usize) -> i64; +} + +impl Encoding for VecZnx { + fn encode_i64_vec(&mut self, log_base2k: usize, log_k: usize, data: &[i64], log_max: usize) { + let limbs: usize = (log_k + log_base2k - 1) / log_base2k; + + assert!(limbs <= self.limbs(), "invalid argument log_k: (log_k + self.log_base2k - 1)/self.log_base2k={} > self.limbs()={}", limbs, self.limbs()); + + let size: usize = min(data.len(), self.n()); + let log_k_rem: usize = log_base2k - (log_k % log_base2k); + + // If 2^{log_base2k} * 2^{k_rem} < 2^{63}-1, then we can simply copy + // values on the last limb. + // Else we decompose values base2k. 
+ if log_max + log_k_rem < 63 || log_k_rem == log_base2k { + (0..limbs - 1).for_each(|i| unsafe { + znx_zero_i64_ref(size as u64, self.at_mut(i).as_mut_ptr()); + }); + self.at_mut(self.limbs() - 1)[..size].copy_from_slice(&data[..size]); + } else { + let mask: i64 = (1 << log_base2k) - 1; + let steps: usize = min(limbs, (log_max + log_base2k - 1) / log_base2k); + + (0..steps).for_each(|i| unsafe { + znx_zero_i64_ref(size as u64, self.at_mut(i).as_mut_ptr()); + }); + + (limbs - steps..limbs) + .rev() + .enumerate() + .for_each(|(i, i_rev)| { + let shift: usize = i * log_base2k; + izip!(self.at_mut(i_rev)[..size].iter_mut(), data[..size].iter()) + .for_each(|(y, x)| *y = (x >> shift) & mask); + }) + } + + // Case where self.prec % self.k != 0. + if log_k_rem != log_base2k { + let limbs = self.limbs(); + let steps: usize = min(limbs, (log_max + log_base2k - 1) / log_base2k); + (limbs - steps..limbs).rev().for_each(|i| { + self.at_mut(i)[..size] + .iter_mut() + .for_each(|x| *x <<= log_k_rem); + }) + } + } + + fn decode_i64_vec(&self, log_base2k: usize, log_k: usize, data: &mut [i64]) { + let limbs: usize = (log_k + log_base2k - 1) / log_base2k; + assert!( + data.len() >= self.n, + "invalid data: data.len()={} < self.n()={}", + data.len(), + self.n + ); + data.copy_from_slice(self.at(0)); + let rem: usize = log_base2k - (log_k % log_base2k); + (1..limbs).for_each(|i| { + if i == limbs - 1 && rem != log_base2k { + let k_rem: usize = log_base2k - rem; + izip!(self.at(i).iter(), data.iter_mut()).for_each(|(x, y)| { + *y = (*y << k_rem) + (x >> rem); + }); + } else { + izip!(self.at(i).iter(), data.iter_mut()).for_each(|(x, y)| { + *y = (*y << log_base2k) + x; + }); + } + }) + } + + fn encode_i64_coeff( + &mut self, + log_base2k: usize, + log_k: usize, + i: usize, + value: i64, + log_max: usize, + ) { + assert!(i < self.n()); + let limbs: usize = (log_k + log_base2k - 1) / log_base2k; + assert!(limbs <= self.limbs(), "invalid argument log_k: (log_k + self.log_base2k - 1)/self.log_base2k={} > self.limbs()={}", limbs, self.limbs()); + let log_k_rem: usize = log_base2k - (log_k % log_base2k); + let limbs = self.limbs(); + + // If 2^{log_base2k} * 2^{log_k_rem} < 2^{63}-1, then we can simply copy + // values on the last limb. + // Else we decompose values base2k. + if log_max + log_k_rem < 63 || log_k_rem == log_base2k { + (0..limbs - 1).for_each(|j| self.at_mut(j)[i] = 0); + + self.at_mut(self.limbs() - 1)[i] = value; + } else { + let mask: i64 = (1 << log_base2k) - 1; + let steps: usize = min(limbs, (log_max + log_base2k - 1) / log_base2k); + + (0..limbs - steps).for_each(|j| self.at_mut(j)[i] = 0); + + (limbs - steps..limbs) + .rev() + .enumerate() + .for_each(|(j, j_rev)| { + self.at_mut(j_rev)[i] = (value >> (j * log_base2k)) & mask; + }) + } + + // Case where self.prec % self.k != 0. 
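+ // e.g. (illustrative values): log_base2k = 17 and log_k = 80 give
+ // log_k % log_base2k = 12, hence log_k_rem = 5: the limbs written above are
+ // shifted left by 5 bits so that the encoded value is aligned on 2^{log_k}
+ // rather than on the limb boundary 2^{85}.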
+ if log_k_rem != log_base2k { + let limbs = self.limbs(); + let steps: usize = min(limbs, (log_max + log_base2k - 1) / log_base2k); + (limbs - steps..limbs).rev().for_each(|j| { + self.at_mut(j)[i] <<= log_k_rem; + }) + } + } + + fn decode_i64_coeff(&self, log_base2k: usize, log_k: usize, i: usize) -> i64 { + let limbs: usize = (log_k + log_base2k - 1) / log_base2k; + assert!(i < self.n()); + let mut res: i64 = self.data[i]; + let rem: usize = log_base2k - (log_k % log_base2k); + (1..limbs).for_each(|i| { + let x = self.data[i * self.n]; + if i == limbs - 1 && rem != log_base2k { + let k_rem: usize = log_base2k - rem; + res = (res << k_rem) + (x >> rem); + } else { + res = (res << log_base2k) + x; + } + }); + res + } +} + +#[cfg(test)] +mod tests { + use crate::{Encoding, VecZnx}; + use itertools::izip; + use sampling::source::Source; + + #[test] + fn test_set_get_i64_lo_norm() { + let n: usize = 8; + let log_base2k: usize = 17; + let limbs: usize = 5; + let log_k: usize = limbs * log_base2k - 5; + let mut a: VecZnx = VecZnx::new(n, limbs); + let mut have: Vec = vec![i64::default(); n]; + have.iter_mut() + .enumerate() + .for_each(|(i, x)| *x = (i as i64) - (n as i64) / 2); + a.encode_i64_vec(log_base2k, log_k, &have, 10); + let mut want = vec![i64::default(); n]; + a.decode_i64_vec(log_base2k, log_k, &mut want); + izip!(want, have).for_each(|(a, b)| assert_eq!(a, b)); + } + + #[test] + fn test_set_get_i64_hi_norm() { + let n: usize = 8; + let log_base2k: usize = 17; + let limbs: usize = 5; + let log_k: usize = limbs * log_base2k - 5; + let mut a: VecZnx = VecZnx::new(n, limbs); + let mut have: Vec = vec![i64::default(); n]; + let mut source = Source::new([1; 32]); + have.iter_mut().for_each(|x| { + *x = source + .next_u64n(u64::MAX, u64::MAX) + .wrapping_sub(u64::MAX / 2 + 1) as i64; + }); + a.encode_i64_vec(log_base2k, log_k, &have, 63); + //(0..a.limbs()).for_each(|i| println!("i:{} -> {:?}", i, a.at(i))); + let mut want = vec![i64::default(); n]; + //(0..a.limbs()).for_each(|i| println!("i:{} -> {:?}", i, a.at(i))); + a.decode_i64_vec(log_base2k, log_k, &mut want); + izip!(want, have).for_each(|(a, b)| assert_eq!(a, b, "{} != {}", a, b)); + } + + #[test] + fn test_normalize() {} +} diff --git a/base2k/src/free.rs b/base2k/src/free.rs new file mode 100644 index 0000000..3ba787a --- /dev/null +++ b/base2k/src/free.rs @@ -0,0 +1,43 @@ +use crate::ffi::svp; +use crate::ffi::vec_znx_big; +use crate::ffi::vec_znx_dft; +use crate::ffi::vmp; +use crate::{SvpPPol, VecZnxBig, VecZnxDft, VmpPMat}; + +/// This trait should be implemented by structs that point to +/// memory allocated through C. +pub trait Free { + // Frees the memory and self destructs. + fn free(self); +} + +impl Free for VmpPMat { + /// Frees the C allocated memory of the [VmpPMat] and self destructs the struct. + fn free(self) { + unsafe { vmp::delete_vmp_pmat(self.data) }; + drop(self); + } +} + +impl Free for VecZnxDft { + fn free(self) { + unsafe { vec_znx_dft::delete_vec_znx_dft(self.0) }; + drop(self); + } +} + +impl Free for VecZnxBig { + fn free(self) { + unsafe { + vec_znx_big::delete_vec_znx_big(self.0); + } + drop(self); + } +} + +impl Free for SvpPPol { + fn free(self) { + unsafe { svp::delete_svp_ppol(self.0) }; + let _ = drop(self); + } +} diff --git a/base2k/src/infos.rs b/base2k/src/infos.rs new file mode 100644 index 0000000..554b747 --- /dev/null +++ b/base2k/src/infos.rs @@ -0,0 +1,77 @@ +use crate::{VecZnx, VmpPMat}; + +pub trait Infos { + /// Returns the ring degree of the receiver. 
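+ /// (i.e. the number of coefficients n of the underlying Z\[X\]/(X^n+1) polynomials).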
fn n(&self) -> usize; + + /// Returns the base two logarithm of the ring degree of the receiver. + fn log_n(&self) -> usize; + + /// Returns the number of limbs of the receiver. + /// This method is equivalent to [Infos::cols]. + fn limbs(&self) -> usize; + + /// Returns the number of columns of the receiver. + /// This method is equivalent to [Infos::limbs]. + fn cols(&self) -> usize; + + /// Returns the number of rows of the receiver. + fn rows(&self) -> usize; +} + +impl Infos for VecZnx { + /// Returns the base 2 logarithm of the [VecZnx] degree. + fn log_n(&self) -> usize { + (usize::BITS - (self.n - 1).leading_zeros()) as _ + } + + /// Returns the [VecZnx] degree. + fn n(&self) -> usize { + self.n + } + + /// Returns the number of limbs of the [VecZnx]. + fn limbs(&self) -> usize { + self.data.len() / self.n + } + + /// Returns the number of columns of the [VecZnx], which is equal to its number of limbs. + fn cols(&self) -> usize { + self.data.len() / self.n + } + + /// Returns the number of rows of the [VecZnx], which is always 1. + fn rows(&self) -> usize { + 1 + } +} + +impl Infos for VmpPMat { + /// Returns the ring degree of the [VmpPMat]. + fn n(&self) -> usize { + self.n + } + + fn log_n(&self) -> usize { + (usize::BITS - (self.n() - 1).leading_zeros()) as _ + } + + /// Returns the number of limbs of each [VecZnxDft]. + /// This method is equivalent to [Self::cols]. + fn limbs(&self) -> usize { + self.cols + } + + /// Returns the number of rows (i.e. the number of [VecZnxDft]) of the [VmpPMat]. + fn rows(&self) -> usize { + self.rows + } + + /// Returns the number of cols of the [VmpPMat]. + /// The number of cols refers to the number of limbs + /// of each [VecZnxDft]. + /// This method is equivalent to [Self::limbs]. + fn cols(&self) -> usize { + self.cols + } +} diff --git a/base2k/src/lib.rs b/base2k/src/lib.rs index 4ef66ee..df2d6f3 100644 --- a/base2k/src/lib.rs +++ b/base2k/src/lib.rs @@ -11,21 +11,13 @@ pub mod module; #[allow(unused_imports)] pub use module::*; -pub mod scalar; -#[allow(unused_imports)] -pub use scalar::*; - pub mod vec_znx; #[allow(unused_imports)] pub use vec_znx::*; -pub mod vec_znx_arithmetic; +pub mod vec_znx_big; #[allow(unused_imports)] -pub use vec_znx_arithmetic::*; - -pub mod vec_znx_big_arithmetic; -#[allow(unused_imports)] -pub use vec_znx_big_arithmetic::*; +pub use vec_znx_big::*; pub mod vec_znx_dft; #[allow(unused_imports)] @@ -39,6 +31,22 @@ pub mod vmp; #[allow(unused_imports)] pub use vmp::*; +pub mod sampling; +#[allow(unused_imports)] +pub use sampling::*; + +pub mod encoding; +#[allow(unused_imports)] +pub use encoding::*; + +pub mod infos; +#[allow(unused_imports)] +pub use infos::*; + +pub mod free; +#[allow(unused_imports)] +pub use free::*; + pub const GALOISGENERATOR: u64 = 5; #[allow(dead_code)] @@ -65,10 +73,3 @@ pub fn cast_u8_to_f64_slice(data: &mut [u8]) -> &[f64] { let len: usize = data.len() / std::mem::size_of::<f64>(); unsafe { std::slice::from_raw_parts(ptr, len) } } - -/// This trait should be implemented by structs that point to -/// memory allocated through C. -pub trait Free { - // Frees the memory and self destructs.
- fn free(self); -} diff --git a/base2k/src/sampling.rs b/base2k/src/sampling.rs new file mode 100644 index 0000000..39f024b --- /dev/null +++ b/base2k/src/sampling.rs @@ -0,0 +1,94 @@ +use crate::{Infos, VecZnx}; +use rand_distr::{Distribution, Normal}; +use sampling::source::Source; + +pub trait Sampling { + /// Fills the first `limbs` limbs with uniform values in \[-2^{log_base2k-1}, 2^{log_base2k-1}\). + fn fill_uniform(&mut self, log_base2k: usize, limbs: usize, source: &mut Source); + + /// Adds a vector sampled according to the provided distribution, scaled by 2^{-log_k} and bounded to \[-bound, bound\]. + fn add_dist_f64<T: Distribution<f64>>( + &mut self, + log_base2k: usize, + log_k: usize, + source: &mut Source, + dist: T, + bound: f64, + ); + + /// Adds a discrete normal vector scaled by 2^{-log_k} with the provided standard deviation and bounded to \[-bound, bound\]. + fn add_normal( + &mut self, + log_base2k: usize, + log_k: usize, + source: &mut Source, + sigma: f64, + bound: f64, + ); +} + +impl Sampling for VecZnx { + fn fill_uniform(&mut self, log_base2k: usize, limbs: usize, source: &mut Source) { + let base2k: u64 = 1 << log_base2k; + let mask: u64 = base2k - 1; + let base2k_half: i64 = (base2k >> 1) as i64; + + let size: usize = self.n() * (limbs - 1); + + self.data[..size] + .iter_mut() + .for_each(|x| *x = (source.next_u64n(base2k, mask) as i64) - base2k_half); + } + + fn add_dist_f64<T: Distribution<f64>>( + &mut self, + log_base2k: usize, + log_k: usize, + source: &mut Source, + dist: T, + bound: f64, + ) { + assert!( + (bound.log2().ceil() as i64) < 64, + "invalid bound: ceil(log2(bound))={} > 63", + (bound.log2().ceil() as i64) + ); + + let log_base2k_rem: usize = log_k % log_base2k; + + if log_base2k_rem != 0 { + self.at_mut(self.limbs() - 1).iter_mut().for_each(|a| { + let mut dist_f64: f64 = dist.sample(source); + while dist_f64.abs() > bound { + dist_f64 = dist.sample(source) + } + *a += (dist_f64.round() as i64) << log_base2k_rem + }); + } else { + self.at_mut(self.limbs() - 1).iter_mut().for_each(|a| { + let mut dist_f64: f64 = dist.sample(source); + while dist_f64.abs() > bound { + dist_f64 = dist.sample(source) + } + *a += dist_f64.round() as i64 + }); + } + } + + fn add_normal( + &mut self, + log_base2k: usize, + log_k: usize, + source: &mut Source, + sigma: f64, + bound: f64, + ) { + self.add_dist_f64( + log_base2k, + log_k, + source, + Normal::new(0.0, sigma).unwrap(), + bound, + ); + } +} diff --git a/base2k/src/scalar.rs b/base2k/src/scalar.rs deleted file mode 100644 index 50ff246..0000000 --- a/base2k/src/scalar.rs +++ /dev/null @@ -1,55 +0,0 @@ -use crate::Module; -use rand::seq::SliceRandom; -use rand_core::RngCore; -use rand_distr::{Distribution, WeightedIndex}; -use sampling::source::Source; - -pub struct Scalar(pub Vec<i64>); - -impl Module { - pub fn new_scalar(&self) -> Scalar { - Scalar::new(self.n()) - } -} - -impl Scalar { - pub fn new(n: usize) -> Self { - Self(vec![i64::default(); Self::buffer_size(n)]) - } - - pub fn buffer_size(n: usize) -> usize { - n - } - - pub fn from_buffer(&mut self, n: usize, buf: &[i64]) { - let size: usize = Self::buffer_size(n); - assert!( - buf.len() >= size, - "invalid buffer: buf.len()={} < self.buffer_size(n={})={}", - buf.len(), - n, - size - ); - self.0 = Vec::from(&buf[..size]) - } - - pub fn as_ptr(&self) -> *const i64 { - self.0.as_ptr() - } - - pub fn fill_ternary_prob(&mut self, prob: f64, source: &mut Source) { - let choices: [i64; 3] = [-1, 0, 1]; - let weights: [f64; 3] = [prob / 2.0, 1.0 - prob, prob / 2.0]; - let dist: WeightedIndex<f64> =
WeightedIndex::new(&weights).unwrap(); - self.0 - .iter_mut() - .for_each(|x: &mut i64| *x = choices[dist.sample(source)]); - } - - pub fn fill_ternary_hw(&mut self, hw: usize, source: &mut Source) { - self.0[..hw] - .iter_mut() - .for_each(|x: &mut i64| *x = (((source.next_u32() & 1) as i64) << 1) - 1); - self.0.shuffle(source); - } -} diff --git a/base2k/src/svp.rs b/base2k/src/svp.rs index bac0416..95372a0 100644 --- a/base2k/src/svp.rs +++ b/base2k/src/svp.rs @@ -1,10 +1,65 @@ -use crate::ffi::svp::{delete_svp_ppol, new_svp_ppol, svp_apply_dft, svp_ppol_t, svp_prepare}; -use crate::scalar::Scalar; +use crate::ffi::svp; use crate::{Free, Module, VecZnx, VecZnxDft}; -pub struct SvpPPol(pub *mut svp_ppol_t, pub usize); +use crate::Infos; +use rand::seq::SliceRandom; +use rand_core::RngCore; +use rand_distr::{Distribution, WeightedIndex}; +use sampling::source::Source; -/// A prepared [crate::Scalar] for [ScalarVectorProduct::svp_apply_dft]. +pub struct Scalar(pub Vec); + +impl Module { + pub fn new_scalar(&self) -> Scalar { + Scalar::new(self.n()) + } +} + +impl Scalar { + pub fn new(n: usize) -> Self { + Self(vec![i64::default(); Self::buffer_size(n)]) + } + + pub fn buffer_size(n: usize) -> usize { + n + } + + pub fn from_buffer(&mut self, n: usize, buf: &[i64]) { + let size: usize = Self::buffer_size(n); + assert!( + buf.len() >= size, + "invalid buffer: buf.len()={} < self.buffer_size(n={})={}", + buf.len(), + n, + size + ); + self.0 = Vec::from(&buf[..size]) + } + + pub fn as_ptr(&self) -> *const i64 { + self.0.as_ptr() + } + + pub fn fill_ternary_prob(&mut self, prob: f64, source: &mut Source) { + let choices: [i64; 3] = [-1, 0, 1]; + let weights: [f64; 3] = [prob / 2.0, 1.0 - prob, prob / 2.0]; + let dist: WeightedIndex = WeightedIndex::new(&weights).unwrap(); + self.0 + .iter_mut() + .for_each(|x: &mut i64| *x = choices[dist.sample(source)]); + } + + pub fn fill_ternary_hw(&mut self, hw: usize, source: &mut Source) { + self.0[..hw] + .iter_mut() + .for_each(|x: &mut i64| *x = (((source.next_u32() & 1) as i64) << 1) - 1); + self.0.shuffle(source); + } +} + +pub struct SvpPPol(pub *mut svp::svp_ppol_t, pub usize); + +/// A prepared [crate::Scalar] for [SvpPPolOps::svp_apply_dft]. /// An [SvpPPol] an be seen as a [VecZnxDft] of one limb. /// The backend array of an [SvpPPol] is allocated in C and must be freed manually. impl SvpPPol { @@ -19,15 +74,8 @@ impl SvpPPol { } } -impl Free for SvpPPol { - fn free(self) { - unsafe { delete_svp_ppol(self.0) }; - let _ = drop(self); - } -} - -pub trait ScalarVectorProduct { - /// Prepares a [crate::Scalar] for a [ScalarVectorProduct::svp_apply_dft]. +pub trait SvpPPolOps { + /// Prepares a [crate::Scalar] for a [SvpPPolOps::svp_apply_dft]. fn svp_prepare(&self, svp_ppol: &mut SvpPPol, a: &Scalar); /// Allocates a new [SvpPPol]. 
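///
/// # Example
/// A minimal sketch of the full pipeline (`module`, `a` and `source` are assumed
/// to be constructed as in the crate examples; their construction is elided):
/// ```ignore
/// let mut s: Scalar = module.new_scalar();
/// s.fill_ternary_prob(0.5, &mut source);
/// let mut s_ppol: SvpPPol = module.svp_new_ppol();
/// module.svp_prepare(&mut s_ppol, &s);
/// // c <- s * a, evaluated limb by limb in the DFT domain.
/// let mut c: VecZnxDft = module.new_vec_znx_dft(a.limbs());
/// module.svp_apply_dft(&mut c, &s_ppol, &a);
/// s_ppol.free();
/// ```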
@@ -38,16 +86,16 @@ pub trait ScalarVectorProduct { fn svp_apply_dft(&self, c: &mut VecZnxDft, a: &SvpPPol, b: &VecZnx); } -impl Module { - pub fn svp_prepare(&self, svp_ppol: &mut SvpPPol, a: &Scalar) { - unsafe { svp_prepare(self.0, svp_ppol.0, a.as_ptr()) } +impl SvpPPolOps for Module { + fn svp_prepare(&self, svp_ppol: &mut SvpPPol, a: &Scalar) { + unsafe { svp::svp_prepare(self.0, svp_ppol.0, a.as_ptr()) } } - pub fn svp_new_ppol(&self) -> SvpPPol { - unsafe { SvpPPol(new_svp_ppol(self.0), self.n()) } + fn svp_new_ppol(&self) -> SvpPPol { + unsafe { SvpPPol(svp::new_svp_ppol(self.0), self.n()) } } - pub fn svp_apply_dft(&self, c: &mut VecZnxDft, a: &SvpPPol, b: &VecZnx) { + fn svp_apply_dft(&self, c: &mut VecZnxDft, a: &SvpPPol, b: &VecZnx) { let limbs: u64 = b.limbs() as u64; assert!( c.limbs() as u64 >= limbs, @@ -55,6 +103,6 @@ impl Module { c.limbs(), limbs ); - unsafe { svp_apply_dft(self.0, c.0, limbs, a.0, b.as_ptr(), limbs, b.n() as u64) } + unsafe { svp::svp_apply_dft(self.0, c.0, limbs, a.0, b.as_ptr(), limbs, b.n() as u64) } } } diff --git a/base2k/src/vec_znx.rs b/base2k/src/vec_znx.rs index 3e775b6..410b3f3 100644 --- a/base2k/src/vec_znx.rs +++ b/base2k/src/vec_znx.rs @@ -1,26 +1,23 @@ use crate::cast_mut_u8_to_mut_i64_slice; -use crate::ffi::znx::{ - znx_automorphism_i64, znx_automorphism_inplace_i64, znx_normalize, znx_zero_i64_ref, -}; -use crate::module::Module; +use crate::ffi::vec_znx; +use crate::ffi::znx; +use crate::{Infos, Module}; use itertools::izip; -use rand_distr::{Distribution, Normal}; -use sampling::source::Source; use std::cmp::min; -impl Module { - pub fn new_vec_znx(&self, limbs: usize) -> VecZnx { - VecZnx::new(self.n(), limbs) - } -} - +/// [VecZnx] represents a vector of small norm polynomials of Zn\[X\] with [i64] coefficients. +/// A [VecZnx] is composed of multiple Zn\[X\] polynomials stored in a single contiguous array +/// in the memory. #[derive(Clone)] pub struct VecZnx { + /// Polynomial degree. pub n: usize, + /// Polynomial coefficients, as a contiguous array. Each limb is equally spaced by n. pub data: Vec, } impl VecZnx { + /// Allocates a new [VecZnx] composed of #limbs polynomials of Z\[X\]. pub fn new(n: usize, limbs: usize) -> Self { Self { n: n, @@ -28,11 +25,14 @@ impl VecZnx { } } + /// Returns the minimum size of the [i64] array required to assign a + /// new backend array to a [VecZnx] through [VecZnx::from_buffer]. pub fn buffer_size(n: usize, limbs: usize) -> usize { n * limbs } - pub fn from_buffer(&mut self, n: usize, limbs: usize, buf: &[i64]) { + /// Assigns a new backing array to a [VecZnx]. + pub fn from_buffer(&mut self, n: usize, limbs: usize, buf: &mut [i64]) { let size = Self::buffer_size(n, limbs); assert!( buf.len() >= size, @@ -46,142 +46,94 @@ impl VecZnx { self.data = Vec::from(&buf[..size]) } - pub fn log_n(&self) -> u64 { - (u64::BITS - (self.n - 1).leading_zeros()) as _ - } - - pub fn n(&self) -> usize { - self.n - } - - pub fn limbs(&self) -> usize { - self.data.len() / self.n - } - + /// Copies the coefficients of `a` on the receiver. + /// Copy is done with the minimum size matching both backing arrays. pub fn copy_from(&mut self, a: &VecZnx) { let size = min(self.data.len(), a.data.len()); self.data[..size].copy_from_slice(&a.data[..size]) } + /// Returns a non-mutable pointer to the backing array of the [VecZnx]. pub fn as_ptr(&self) -> *const i64 { self.data.as_ptr() } + /// Returns a mutable pointer to the backing array of the [VecZnx]. 
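+ /// The i-th limb starts at offset i * n in this array.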
pub fn as_mut_ptr(&mut self) -> *mut i64 { self.data.as_mut_ptr() } + /// Returns a non-mutable reference to the i-th limb of the [VecZnx]. pub fn at(&self, i: usize) -> &[i64] { &self.data[i * self.n..(i + 1) * self.n] } - pub fn at_ptr(&self, i: usize) -> *const i64 { - &self.data[i * self.n] as *const i64 - } - - pub fn at_mut_ptr(&mut self, i: usize) -> *mut i64 { - &mut self.data[i * self.n] as *mut i64 - } - + /// Returns a mutable reference to the i-th limb of the [VecZnx]. pub fn at_mut(&mut self, i: usize) -> &mut [i64] { &mut self.data[i * self.n..(i + 1) * self.n] } + /// Returns a non-mutable pointer to the i-th limb of the [VecZnx]. + pub fn at_ptr(&self, i: usize) -> *const i64 { + &self.data[i * self.n] as *const i64 + } + + /// Returns a mutable pointer to the i-th limb of the [VecZnx]. + pub fn at_mut_ptr(&mut self, i: usize) -> *mut i64 { + &mut self.data[i * self.n] as *mut i64 + } + + /// Zeroes the backing array of the [VecZnx]. pub fn zero(&mut self) { - unsafe { znx_zero_i64_ref(self.data.len() as u64, self.data.as_mut_ptr()) } - } - - pub fn from_i64(&mut self, log_base2k: usize, data: &[i64], log_max: usize, log_k: usize) { - let limbs: usize = (log_k + log_base2k - 1) / log_base2k; - - assert!(limbs <= self.limbs(), "invalid argument log_k: (log_k + self.log_base2k - 1)/self.log_base2k={} > self.limbs()={}", limbs, self.limbs()); - - let size: usize = min(data.len(), self.n()); - let log_k_rem: usize = log_base2k - (log_k % log_base2k); - - // If 2^{log_base2k} * 2^{k_rem} < 2^{63}-1, then we can simply copy - // values on the last limb. - // Else we decompose values base2k. - if log_max + log_k_rem < 63 || log_k_rem == log_base2k { - (0..limbs - 1).for_each(|i| unsafe { - znx_zero_i64_ref(size as u64, self.at_mut(i).as_mut_ptr()); - }); - self.at_mut(self.limbs() - 1)[..size].copy_from_slice(&data[..size]); - } else { - let mask: i64 = (1 << log_base2k) - 1; - let steps: usize = min(limbs, (log_max + log_base2k - 1) / log_base2k); - - (0..steps).for_each(|i| unsafe { - znx_zero_i64_ref(size as u64, self.at_mut(i).as_mut_ptr()); - }); - - (limbs - steps..limbs) - .rev() - .enumerate() - .for_each(|(i, i_rev)| { - let shift: usize = i * log_base2k; - izip!(self.at_mut(i_rev)[..size].iter_mut(), data[..size].iter()) - .for_each(|(y, x)| *y = (x >> shift) & mask); - }) - } - - // Case where self.prec % self.k != 0. - if log_k_rem != log_base2k { - let limbs = self.limbs(); - let steps: usize = min(limbs, (log_max + log_base2k - 1) / log_base2k); - (limbs - steps..limbs).rev().for_each(|i| { - self.at_mut(i)[..size] - .iter_mut() - .for_each(|x| *x <<= log_k_rem); - }) - } - } - - pub fn from_i64_single( - &mut self, - log_base2k: usize, - i: usize, - value: i64, - log_max: usize, - log_k: usize, - ) { - assert!(i < self.n()); - let limbs: usize = (log_k + log_base2k - 1) / log_base2k; - assert!(limbs <= self.limbs(), "invalid argument log_k: (log_k + self.log_base2k - 1)/self.log_base2k={} > self.limbs()={}", limbs, self.limbs()); - let log_k_rem: usize = log_base2k - (log_k % log_base2k); - let limbs = self.limbs(); - - // If 2^{log_base2k} * 2^{log_k_rem} < 2^{63}-1, then we can simply copy - // values on the last limb. - // Else we decompose values base2k. 
- if log_max + log_k_rem < 63 || log_k_rem == log_base2k { - (0..limbs - 1).for_each(|j| self.at_mut(j)[i] = 0); - - self.at_mut(self.limbs() - 1)[i] = value; - } else { - let mask: i64 = (1 << log_base2k) - 1; - let steps: usize = min(limbs, (log_max + log_base2k - 1) / log_base2k); - - (0..limbs - steps).for_each(|j| self.at_mut(j)[i] = 0); - - (limbs - steps..limbs) - .rev() - .enumerate() - .for_each(|(j, j_rev)| { - self.at_mut(j_rev)[i] = (value >> (j * log_base2k)) & mask; - }) - } - - // Case where self.prec % self.k != 0. - if log_k_rem != log_base2k { - let limbs = self.limbs(); - let steps: usize = min(limbs, (log_max + log_base2k - 1) / log_base2k); - (limbs - steps..limbs).rev().for_each(|j| { - self.at_mut(j)[i] <<= log_k_rem; - }) - } } + unsafe { znx::znx_zero_i64_ref(self.data.len() as u64, self.data.as_mut_ptr()) } } + /// Normalizes the [VecZnx], ensuring all coefficients are in the interval \[-2^{log_base2k-1}, 2^{log_base2k-1}\]. + /// + /// # Arguments + /// + /// * `log_base2k`: the base two logarithm of the base to reduce to. + /// * `carry`: scratch space of size at least self.n()<<3. + /// + /// # Panics + /// + /// The method will panic if carry.len() < self.n() * 8. + /// + /// # Example + /// ``` + /// use base2k::{VecZnx, Encoding, Infos}; + /// use itertools::izip; + /// use sampling::source::Source; + /// + /// let n: usize = 8; // polynomial degree + /// let log_base2k: usize = 17; // base two logarithm of the coefficients decomposition + /// let limbs: usize = 5; // number of limbs (i.e. can store coeffs in the range +/- 2^{limbs * log_base2k - 1}) + /// let log_k: usize = limbs * log_base2k - 5; + /// let mut a: VecZnx = VecZnx::new(n, limbs); + /// let mut carry: Vec<u8> = vec![u8::default(); a.n()<<3]; + /// let mut have: Vec<i64> = vec![i64::default(); a.n()]; + /// let mut source = Source::new([1; 32]); + /// + /// // Populates the input vector with random i64 values. + /// have.iter_mut().for_each(|x| { + /// *x = source + /// .next_u64n(u64::MAX, u64::MAX) + /// .wrapping_sub(u64::MAX / 2 + 1) as i64; + /// }); + /// a.encode_i64_vec(log_base2k, log_k, &have, 63); + /// a.normalize(log_base2k, &mut carry); + /// + /// // Ensures normalized values are in the range +/- 2^{log_base2k-1} + /// let base_half = 1 << (log_base2k - 1); + /// a.data + /// .iter() + /// .for_each(|x| assert!(x.abs() <= base_half, "|x|={} > 2^(k-1)={}", x, base_half)); + /// + /// // Ensures reconstructed normalized values are equal to non-normalized values. + /// let mut want = vec![i64::default(); n]; + /// a.decode_i64_vec(log_base2k, log_k, &mut want); + /// izip!(want, have).for_each(|(a, b)| assert_eq!(a, b, "{} != {}", a, b)); + /// ``` pub fn normalize(&mut self, log_base2k: usize, carry: &mut [u8]) { assert!( carry.len() >= self.n * 8, @@ -193,9 +145,9 @@ impl VecZnx { let carry_i64: &mut [i64] = cast_mut_u8_to_mut_i64_slice(carry); unsafe { - znx_zero_i64_ref(self.n() as u64, carry_i64.as_mut_ptr()); + znx::znx_zero_i64_ref(self.n() as u64, carry_i64.as_mut_ptr()); (0..self.limbs()).rev().for_each(|i| { - znx_normalize( + znx::znx_normalize( self.n as u64, log_base2k as u64, self.at_mut_ptr(i), @@ -207,120 +159,116 @@ impl VecZnx { } } - pub fn to_i64(&self, log_base2k: usize, data: &mut [i64], log_k: usize) { - let limbs: usize = (log_k + log_base2k - 1) / log_base2k; + /// Maps X^i to X^{ik} mod X^{n}+1. The mapping is applied independently on each limb. + /// + /// # Arguments + /// + /// * `k`: the power to which to map each coefficient.
+ /// * `limbs`: the number of limbs on which to apply the mapping. + /// + /// # Panics + /// + /// The method will panic if the argument `limbs` is greater than `self.limbs()`. + /// + /// # Example + /// ``` + /// use base2k::{VecZnx, Encoding, Infos}; + /// use itertools::izip; + /// + /// let n: usize = 8; // polynomial degree + /// let mut a: VecZnx = VecZnx::new(n, 2); + /// let mut b: VecZnx = VecZnx::new(n, 2); + /// + /// (0..a.limbs()).for_each(|i|{ + /// a.at_mut(i).iter_mut().enumerate().for_each(|(i, x)|{ + /// *x = i as i64 + /// }) + /// }); + /// + /// b.copy_from(&a); + /// + /// a.automorphism_inplace(-1, 1); // X^i -> X^(-i) + /// let limb = b.at_mut(0); + /// (1..limb.len()).for_each(|i|{ + /// limb[n-i] = -(i as i64) + /// }); + /// izip!(a.data.iter(), b.data.iter()).for_each(|(a, b)| assert_eq!(a, b, "{} != {}", a, b)); + /// ``` + pub fn automorphism_inplace(&mut self, k: i64, limbs: usize) { assert!( - data.len() >= self.n, - "invalid data: data.len()={} < self.n()={}", - data.len(), - self.n + limbs <= self.limbs(), + "invalid limbs argument: limbs={} > self.limbs()={}", + limbs, + self.limbs() ); - data.copy_from_slice(self.at(0)); - let rem: usize = log_base2k - (log_k % log_base2k); - (1..limbs).for_each(|i| { - if i == limbs - 1 && rem != log_base2k { - let k_rem: usize = log_base2k - rem; - izip!(self.at(i).iter(), data.iter_mut()).for_each(|(x, y)| { - *y = (*y << k_rem) + (x >> rem); - }); - } else { - izip!(self.at(i).iter(), data.iter_mut()).for_each(|(x, y)| { - *y = (*y << log_base2k) + x; - }); - } - }) - } - - pub fn to_i64_single(&self, log_base2k: usize, i: usize, log_k: usize) -> i64 { - let limbs: usize = (log_k + log_base2k - 1) / log_base2k; - assert!(i < self.n()); - let mut res: i64 = self.data[i]; - let rem: usize = log_base2k - (log_k % log_base2k); - (1..limbs).for_each(|i| { - let x = self.data[i * self.n]; - if i == limbs - 1 && rem != log_base2k { - let k_rem: usize = log_base2k - rem; - res = (res << k_rem) + (x >> rem); - } else { - res = (res << log_base2k) + x; - } - }); - res - } - - pub fn automorphism_inplace(&mut self, gal_el: i64) { unsafe { - (0..self.limbs()).for_each(|i| { - znx_automorphism_inplace_i64(self.n as u64, gal_el, self.at_mut_ptr(i)) - }) - } - } - pub fn automorphism(&mut self, gal_el: i64, a: &mut VecZnx) { - unsafe { - (0..self.limbs()).for_each(|i| { - znx_automorphism_i64(self.n as u64, gal_el, a.at_mut_ptr(i), self.at_ptr(i)) + (0..limbs).for_each(|i| { + znx::znx_automorphism_inplace_i64(self.n as u64, k, self.at_mut_ptr(i)) }) } } - pub fn fill_uniform(&mut self, log_base2k: usize, source: &mut Source, limbs: usize) { - let base2k: u64 = 1 << log_base2k; - let mask: u64 = base2k - 1; - let base2k_half: i64 = (base2k >> 1) as i64; - - let size: usize = self.n() * (limbs - 1); - - self.data[..size] - .iter_mut() - .for_each(|x| *x = (source.next_u64n(base2k, mask) as i64) - base2k_half); - } - - pub fn add_dist_f64>( - &mut self, - log_base2k: usize, - source: &mut Source, - dist: T, - bound: f64, - log_k: usize, - ) { - let log_base2k_rem: usize = log_k % log_base2k; - - if log_base2k_rem != 0 { - self.at_mut(self.limbs() - 1).iter_mut().for_each(|a| { - let mut dist_f64: f64 = dist.sample(source); - while dist_f64.abs() > bound { - dist_f64 = dist.sample(source) - } - *a += (dist_f64.round() as i64) << log_base2k_rem - }); - } else { - self.at_mut(self.limbs() - 1).iter_mut().for_each(|a| { - let mut dist_f64: f64 = dist.sample(source); - while dist_f64.abs() > bound { - dist_f64 = dist.sample(source) - } 
- *a += dist_f64.round() as i64 - }); - } - } - + /// Maps X^i to X^{ik} mod X^{n}+1. The mapping is applied independently on each limb. + /// + /// # Arguments + /// + /// * `a`: the output [VecZnx] in which the mapped coefficients are stored. + /// * `k`: the power to which to map each coefficient. + /// * `limbs`: the number of limbs on which to apply the mapping. + /// + /// # Panics + /// + /// The method will panic if the argument `limbs` is greater than `self.limbs()` or `a.limbs()`. + /// + /// # Example + /// ``` + /// use base2k::{VecZnx, Encoding, Infos}; + /// use itertools::izip; + /// + /// let n: usize = 8; // polynomial degree + /// let mut a: VecZnx = VecZnx::new(n, 2); + /// let mut b: VecZnx = VecZnx::new(n, 2); + /// let mut c: VecZnx = VecZnx::new(n, 2); + /// + /// (0..a.limbs()).for_each(|i|{ + /// a.at_mut(i).iter_mut().enumerate().for_each(|(i, x)|{ + /// *x = i as i64 + /// }) + /// }); + /// + /// a.automorphism(&mut b, -1, 1); // X^i -> X^(-i) + /// let limb = c.at_mut(0); + /// (1..limb.len()).for_each(|i|{ + /// limb[n-i] = -(i as i64) + /// }); + /// izip!(b.data.iter(), c.data.iter()).for_each(|(a, b)| assert_eq!(a, b, "{} != {}", a, b)); + /// ``` + pub fn automorphism(&mut self, a: &mut VecZnx, k: i64, limbs: usize) { + assert!( + limbs <= self.limbs(), + "invalid limbs argument: limbs={} > self.limbs()={}", + limbs, + self.limbs() + ); + assert!( + limbs <= a.limbs(), + "invalid limbs argument: limbs={} > a.limbs()={}", + limbs, + a.limbs() + ); + unsafe { + (0..limbs).for_each(|i| { + znx::znx_automorphism_i64(self.n as u64, k, a.at_mut_ptr(i), self.at_ptr(i)) + }) } } - pub fn add_normal( - &mut self, - log_base2k: usize, - source: &mut Source, - sigma: f64, - bound: f64, - log_k: usize, - ) { - self.add_dist_f64( - log_base2k, - source, - Normal::new(0.0, sigma).unwrap(), - bound, - log_k, - ); - } - + /// Truncates the precision of the [VecZnx] by k bits. + /// + /// # Arguments + /// + /// * `log_base2k`: the base two logarithm of the coefficients decomposition. + /// * `k`: the number of bits of precision to drop. pub fn trunc_pow2(&mut self, log_base2k: usize, k: usize) { if k == 0 { return; @@ -339,6 +287,17 @@ impl VecZnx { } } + /// Right-shifts the coefficients by k bits. + /// + /// # Arguments + /// + /// * `log_base2k`: the base two logarithm of the coefficients decomposition. + /// * `k`: the shift amount. + /// * `carry`: scratch space of size at least equal to self.n() << 3. + /// + /// # Panics + /// + /// The method will panic if carry.len() < self.n() << 3. pub fn rsh(&mut self, log_base2k: usize, k: usize, carry: &mut [u8]) { assert!( carry.len() >> 3 >= self.n(), @@ -352,7 +311,7 @@ impl VecZnx { self.data.rotate_right(self.n * limbs_steps); unsafe { - znx_zero_i64_ref((self.n * limbs_steps) as u64, self.data.as_mut_ptr()); + znx::znx_zero_i64_ref((self.n * limbs_steps) as u64, self.data.as_mut_ptr()); } let k_rem = k % log_base2k; @@ -361,7 +320,7 @@ impl VecZnx { let carry_i64: &mut [i64] = cast_mut_u8_to_mut_i64_slice(carry); unsafe { - znx_zero_i64_ref(self.n() as u64, carry_i64.as_mut_ptr()); + znx::znx_zero_i64_ref(self.n() as u64, carry_i64.as_mut_ptr()); } let mask: i64 = (1 << k_rem) - 1; @@ -377,6 +336,12 @@ impl VecZnx { } } + /// If self.n() > a.n(): maps X^{i*self.n()/a.n()} -> X^{i}. + /// If self.n() < a.n(): maps X^{i} -> X^{i*a.n()/self.n()}. + /// + /// # Arguments + /// + /// * `a`: the receiver polynomial in which the extracted coefficients are stored.
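+ /// + /// For example (hypothetical sizes), with self.n() = 8 and a.n() = 4, the coefficients of X^0, X^2, X^4, X^6 of the receiver are mapped to X^0, X^1, X^2, X^3 of `a`.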
pub fn switch_degree(&self, a: &mut VecZnx) { let (n_in, n_out) = (self.n(), a.n()); let (gap_in, gap_out): (usize, usize); @@ -404,74 +369,207 @@ impl VecZnx { } } -#[cfg(test)] -mod tests { - use crate::VecZnx; - use itertools::izip; - use sampling::source::Source; +pub trait VecZnxOps { + /// Allocates a new [VecZnx]. + /// + /// # Arguments + /// + /// * `limbs`: the number of limbs. + fn new_vec_znx(&self, limbs: usize) -> VecZnx; - #[test] - fn test_set_get_i64_lo_norm() { - let n: usize = 8; - let log_base2k: usize = 17; - let limbs: usize = 5; - let log_k: usize = limbs * log_base2k - 5; - let mut a: VecZnx = VecZnx::new(n, limbs); - let mut have: Vec = vec![i64::default(); n]; - have.iter_mut() - .enumerate() - .for_each(|(i, x)| *x = (i as i64) - (n as i64) / 2); - a.from_i64(log_base2k, &have, 10, log_k); - let mut want = vec![i64::default(); n]; - a.to_i64(log_base2k, &mut want, log_k); - izip!(want, have).for_each(|(a, b)| assert_eq!(a, b)); + /// c <- a + b. + fn vec_znx_add(&self, c: &mut VecZnx, a: &VecZnx, b: &VecZnx); + + /// b <- b + a. + fn vec_znx_add_inplace(&self, b: &mut VecZnx, a: &VecZnx); + + /// c <- a - b. + fn vec_znx_sub(&self, c: &mut VecZnx, a: &VecZnx, b: &VecZnx); + + /// b <- b - a. + fn vec_znx_sub_inplace(&self, b: &mut VecZnx, a: &VecZnx); + + /// b <- -a. + fn vec_znx_negate(&self, b: &mut VecZnx, a: &VecZnx); + + /// b <- -b. + fn vec_znx_negate_inplace(&self, a: &mut VecZnx); + + /// b <- a * X^k (mod X^{n} + 1) + fn vec_znx_rotate(&self, k: i64, b: &mut VecZnx, a: &VecZnx); + + /// a <- a * X^k (mod X^{n} + 1) + fn vec_znx_rotate_inplace(&self, k: i64, a: &mut VecZnx); + + /// b <- phi_k(a) where phi_k: X^i -> X^{i*k} (mod (X^{n} + 1)) + fn vec_znx_automorphism(&self, k: i64, b: &mut VecZnx, a: &VecZnx); + + /// a <- phi_k(a) where phi_k: X^i -> X^{i*k} (mod (X^{n} + 1)) + fn vec_znx_automorphism_inplace(&self, k: i64, a: &mut VecZnx); +} + +impl VecZnxOps for Module { + fn new_vec_znx(&self, limbs: usize) -> VecZnx { + VecZnx::new(self.n(), limbs) } - #[test] - fn test_set_get_i64_hi_norm() { - let n: usize = 8; - let log_base2k: usize = 17; - let limbs: usize = 5; - let log_k: usize = limbs * log_base2k - 5; - let mut a: VecZnx = VecZnx::new(n, limbs); - let mut have: Vec = vec![i64::default(); n]; - let mut source = Source::new([1; 32]); - have.iter_mut().for_each(|x| { - *x = source - .next_u64n(u64::MAX, u64::MAX) - .wrapping_sub(u64::MAX / 2 + 1) as i64; - }); - a.from_i64(log_base2k, &have, 63, log_k); - //(0..a.limbs()).for_each(|i| println!("i:{} -> {:?}", i, a.at(i))); - let mut want = vec![i64::default(); n]; - //(0..a.limbs()).for_each(|i| println!("i:{} -> {:?}", i, a.at(i))); - a.to_i64(log_base2k, &mut want, log_k); - izip!(want, have).for_each(|(a, b)| assert_eq!(a, b, "{} != {}", a, b)); + // c <- a + b + fn vec_znx_add(&self, c: &mut VecZnx, a: &VecZnx, b: &VecZnx) { + unsafe { + vec_znx::vec_znx_add( + self.0, + c.as_mut_ptr(), + c.limbs() as u64, + c.n() as u64, + a.as_ptr(), + a.limbs() as u64, + a.n() as u64, + b.as_ptr(), + b.limbs() as u64, + b.n() as u64, + ) + } } - #[test] - fn test_normalize() { - let n: usize = 8; - let log_base2k: usize = 17; - let limbs: usize = 5; - let log_k: usize = limbs * log_base2k - 5; - let mut a: VecZnx = VecZnx::new(n, limbs); - let mut have: Vec = vec![i64::default(); n]; - let mut source = Source::new([1; 32]); - have.iter_mut().for_each(|x| { - *x = source - .next_u64n(u64::MAX, u64::MAX) - .wrapping_sub(u64::MAX / 2 + 1) as i64; - }); - a.from_i64(log_base2k, &have, 63, log_k); - let 
mut carry: Vec<u8> = vec![u8::default(); n * 8]; - a.normalize(log_base2k, &mut carry); - let base_half = 1 << (log_base2k - 1); - a.data - .iter() - .for_each(|x| assert!(x.abs() <= base_half, "|x|={} > 2^(k-1)={}", x, base_half)); - let mut want = vec![i64::default(); n]; - a.to_i64(log_base2k, &mut want, log_k); - izip!(want, have).for_each(|(a, b)| assert_eq!(a, b, "{} != {}", a, b)); + // b <- a + b + fn vec_znx_add_inplace(&self, b: &mut VecZnx, a: &VecZnx) { + unsafe { + vec_znx::vec_znx_add( + self.0, + b.as_mut_ptr(), + b.limbs() as u64, + b.n() as u64, + a.as_ptr(), + a.limbs() as u64, + a.n() as u64, + b.as_ptr(), + b.limbs() as u64, + b.n() as u64, + ) + } + } + + // c <- a - b + fn vec_znx_sub(&self, c: &mut VecZnx, a: &VecZnx, b: &VecZnx) { + unsafe { + vec_znx::vec_znx_sub( + self.0, + c.as_mut_ptr(), + c.limbs() as u64, + c.n() as u64, + a.as_ptr(), + a.limbs() as u64, + a.n() as u64, + b.as_ptr(), + b.limbs() as u64, + b.n() as u64, + ) + } + } + + // b <- b - a + fn vec_znx_sub_inplace(&self, b: &mut VecZnx, a: &VecZnx) { + unsafe { + vec_znx::vec_znx_sub( + self.0, + b.as_mut_ptr(), + b.limbs() as u64, + b.n() as u64, + a.as_ptr(), + a.limbs() as u64, + a.n() as u64, + b.as_ptr(), + b.limbs() as u64, + b.n() as u64, + ) + } + } + + fn vec_znx_negate(&self, b: &mut VecZnx, a: &VecZnx) { + unsafe { + vec_znx::vec_znx_negate( + self.0, + b.as_mut_ptr(), + b.limbs() as u64, + b.n() as u64, + a.as_ptr(), + a.limbs() as u64, + a.n() as u64, + ) + } + } + + fn vec_znx_negate_inplace(&self, a: &mut VecZnx) { + unsafe { + vec_znx::vec_znx_negate( + self.0, + a.as_mut_ptr(), + a.limbs() as u64, + a.n() as u64, + a.as_ptr(), + a.limbs() as u64, + a.n() as u64, + ) + } + } + + fn vec_znx_rotate(&self, k: i64, a: &mut VecZnx, b: &VecZnx) { + unsafe { + vec_znx::vec_znx_rotate( + self.0, + k, + a.as_mut_ptr(), + a.limbs() as u64, + a.n() as u64, + b.as_ptr(), + b.limbs() as u64, + b.n() as u64, + ) + } + } + + fn vec_znx_rotate_inplace(&self, k: i64, a: &mut VecZnx) { + unsafe { + vec_znx::vec_znx_rotate( + self.0, + k, + a.as_mut_ptr(), + a.limbs() as u64, + a.n() as u64, + a.as_ptr(), + a.limbs() as u64, + a.n() as u64, + ) + } + } + + fn vec_znx_automorphism(&self, k: i64, b: &mut VecZnx, a: &VecZnx) { + unsafe { + vec_znx::vec_znx_automorphism( + self.0, + k, + b.as_mut_ptr(), + b.limbs() as u64, + b.n() as u64, + a.as_ptr(), + a.limbs() as u64, + a.n() as u64, + ); + } + } + + fn vec_znx_automorphism_inplace(&self, k: i64, a: &mut VecZnx) { + unsafe { + vec_znx::vec_znx_automorphism( + self.0, + k, + a.as_mut_ptr(), + a.limbs() as u64, + a.n() as u64, + a.as_ptr(), + a.limbs() as u64, + a.n() as u64, + ); + } } } diff --git a/base2k/src/vec_znx_arithmetic.rs b/base2k/src/vec_znx_arithmetic.rs deleted file mode 100644 index 5e8bb08..0000000 --- a/base2k/src/vec_znx_arithmetic.rs +++ /dev/null @@ -1,168 +0,0 @@ -use crate::ffi::vec_znx::{ - vec_znx_add, vec_znx_automorphism, vec_znx_negate, vec_znx_rotate, vec_znx_sub, -}; -use crate::{Module, VecZnx}; - -impl Module { - // c <- a + b - pub fn vec_znx_add(&self, c: &mut VecZnx, a: &VecZnx, b: &VecZnx) { - unsafe { - vec_znx_add( - self.0, - c.as_mut_ptr(), - c.limbs() as u64, - c.n() as u64, - a.as_ptr(), - a.limbs() as u64, - a.n() as u64, - b.as_ptr(), - b.limbs() as u64, - b.n() as u64, - ) - } - } - - // b <- a + b - pub fn vec_znx_add_inplace(&self, b: &mut VecZnx, a: &VecZnx) { - unsafe { - vec_znx_add( - self.0, - b.as_mut_ptr(), - b.limbs() as u64, - b.n() as u64, - a.as_ptr(), - a.limbs() as u64, - a.n() as u64, - b.as_ptr(),
- b.limbs() as u64, - b.n() as u64, - ) - } - } - - // c <- a + b - pub fn vec_znx_sub(&self, c: &mut VecZnx, a: &VecZnx, b: &VecZnx) { - unsafe { - vec_znx_sub( - self.0, - c.as_mut_ptr(), - c.limbs() as u64, - c.n() as u64, - a.as_ptr(), - a.limbs() as u64, - a.n() as u64, - b.as_ptr(), - b.limbs() as u64, - b.n() as u64, - ) - } - } - - // b <- a + b - pub fn vec_znx_sub_inplace(&self, b: &mut VecZnx, a: &VecZnx) { - unsafe { - vec_znx_sub( - self.0, - b.as_mut_ptr(), - b.limbs() as u64, - b.n() as u64, - a.as_ptr(), - a.limbs() as u64, - a.n() as u64, - b.as_ptr(), - b.limbs() as u64, - b.n() as u64, - ) - } - } - - pub fn vec_znx_negate(&self, b: &mut VecZnx, a: &VecZnx) { - unsafe { - vec_znx_negate( - self.0, - b.as_mut_ptr(), - b.limbs() as u64, - b.n() as u64, - a.as_ptr(), - a.limbs() as u64, - a.n() as u64, - ) - } - } - - pub fn vec_znx_negate_inplace(&self, a: &mut VecZnx) { - unsafe { - vec_znx_negate( - self.0, - a.as_mut_ptr(), - a.limbs() as u64, - a.n() as u64, - a.as_ptr(), - a.limbs() as u64, - a.n() as u64, - ) - } - } - - pub fn vec_znx_rotate(&self, k: i64, a: &mut VecZnx, b: &VecZnx) { - unsafe { - vec_znx_rotate( - self.0, - k, - a.as_mut_ptr(), - a.limbs() as u64, - a.n() as u64, - b.as_ptr(), - b.limbs() as u64, - b.n() as u64, - ) - } - } - - pub fn vec_znx_rotate_inplace(&self, k: i64, a: &mut VecZnx) { - unsafe { - vec_znx_rotate( - self.0, - k, - a.as_mut_ptr(), - a.limbs() as u64, - a.n() as u64, - a.as_ptr(), - a.limbs() as u64, - a.n() as u64, - ) - } - } - - // b <- a(X^gal_el) - pub fn vec_znx_automorphism(&self, gal_el: i64, b: &mut VecZnx, a: &VecZnx) { - unsafe { - vec_znx_automorphism( - self.0, - gal_el, - b.as_mut_ptr(), - b.limbs() as u64, - b.n() as u64, - a.as_ptr(), - a.limbs() as u64, - a.n() as u64, - ); - } - } - - // a <- a(X^gal_el) - pub fn vec_znx_automorphism_inplace(&self, gal_el: i64, a: &mut VecZnx) { - unsafe { - vec_znx_automorphism( - self.0, - gal_el, - a.as_mut_ptr(), - a.limbs() as u64, - a.n() as u64, - a.as_ptr(), - a.limbs() as u64, - a.n() as u64, - ); - } - } -} diff --git a/base2k/src/vec_znx_big_arithmetic.rs b/base2k/src/vec_znx_big.rs similarity index 78% rename from base2k/src/vec_znx_big_arithmetic.rs rename to base2k/src/vec_znx_big.rs index 32c6e74..874c83f 100644 --- a/base2k/src/vec_znx_big_arithmetic.rs +++ b/base2k/src/vec_znx_big.rs @@ -1,36 +1,22 @@ -use crate::ffi::vec_znx_big::{ - delete_vec_znx_big, new_vec_znx_big, vec_znx_big_add_small, vec_znx_big_automorphism, - vec_znx_big_normalize_base2k, vec_znx_big_normalize_base2k_tmp_bytes, vec_znx_big_sub_small_a, - vec_znx_bigcoeff_t, -}; -use crate::ffi::vec_znx_dft::vec_znx_dft_t; -use crate::Free; -use crate::{Module, VecZnx, VecZnxDft}; +use crate::ffi::vec_znx_big; +use crate::ffi::vec_znx_dft; +use crate::{Infos, Module, VecZnx, VecZnxDft}; -pub struct VecZnxBig(pub *mut vec_znx_bigcoeff_t, pub usize); +pub struct VecZnxBig(pub *mut vec_znx_big::vec_znx_bigcoeff_t, pub usize); impl VecZnxBig { pub fn as_vec_znx_dft(&mut self) -> VecZnxDft { - VecZnxDft(self.0 as *mut vec_znx_dft_t, self.1) + VecZnxDft(self.0 as *mut vec_znx_dft::vec_znx_dft_t, self.1) } pub fn limbs(&self) -> usize { self.1 } } -impl Free for VecZnxBig { - fn free(self) { - unsafe { - delete_vec_znx_big(self.0); - } - drop(self); - } -} - impl Module { // Allocates a vector Z[X]/(X^N+1) that stores not normalized values. 
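// A [VecZnxBig] holds intermediate (e.g. accumulated) results whose coefficients may exceed the base-2^k range; [Module::vec_znx_big_normalize] reduces it back to a normalized [VecZnx].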
pub fn new_vec_znx_big(&self, limbs: usize) -> VecZnxBig { - unsafe { VecZnxBig(new_vec_znx_big(self.0, limbs as u64), limbs) } + unsafe { VecZnxBig(vec_znx_big::new_vec_znx_big(self.0, limbs as u64), limbs) } } // b <- b - a @@ -43,7 +29,7 @@ impl Module { limbs ); unsafe { - vec_znx_big_sub_small_a( + vec_znx_big::vec_znx_big_sub_small_a( self.0, b.0, b.limbs() as u64, @@ -72,7 +58,7 @@ impl Module { limbs ); unsafe { - vec_znx_big_sub_small_a( + vec_znx_big::vec_znx_big_sub_small_a( self.0, c.0, c.limbs() as u64, @@ -101,7 +87,7 @@ impl Module { limbs ); unsafe { - vec_znx_big_add_small( + vec_znx_big::vec_znx_big_add_small( self.0, c.0, limbs as u64, @@ -124,7 +110,7 @@ impl Module { limbs ); unsafe { - vec_znx_big_add_small( + vec_znx_big::vec_znx_big_add_small( self.0, b.0, limbs as u64, @@ -138,7 +124,7 @@ impl Module { } pub fn vec_znx_big_normalize_tmp_bytes(&self) -> usize { - unsafe { vec_znx_big_normalize_base2k_tmp_bytes(self.0) as usize } + unsafe { vec_znx_big::vec_znx_big_normalize_base2k_tmp_bytes(self.0) as usize } } // b <- normalize(a) @@ -163,7 +149,7 @@ impl Module { self.vec_znx_big_normalize_tmp_bytes() ); unsafe { - vec_znx_big_normalize_base2k( + vec_znx_big::vec_znx_big_normalize_base2k( self.0, log_base2k as u64, b.as_mut_ptr(), @@ -178,13 +164,27 @@ impl Module { pub fn vec_znx_big_automorphism(&self, gal_el: i64, b: &mut VecZnxBig, a: &VecZnxBig) { unsafe { - vec_znx_big_automorphism(self.0, gal_el, b.0, b.limbs() as u64, a.0, a.limbs() as u64); + vec_znx_big::vec_znx_big_automorphism( + self.0, + gal_el, + b.0, + b.limbs() as u64, + a.0, + a.limbs() as u64, + ); } } pub fn vec_znx_big_automorphism_inplace(&self, gal_el: i64, a: &mut VecZnxBig) { unsafe { - vec_znx_big_automorphism(self.0, gal_el, a.0, a.limbs() as u64, a.0, a.limbs() as u64); + vec_znx_big::vec_znx_big_automorphism( + self.0, + gal_el, + a.0, + a.limbs() as u64, + a.0, + a.limbs() as u64, + ); } } } diff --git a/base2k/src/vec_znx_dft.rs b/base2k/src/vec_znx_dft.rs index 151bb5e..74732b9 100644 --- a/base2k/src/vec_znx_dft.rs +++ b/base2k/src/vec_znx_dft.rs @@ -1,32 +1,22 @@ -use crate::ffi::vec_znx_big::vec_znx_bigcoeff_t; -use crate::ffi::vec_znx_dft::{ - delete_vec_znx_dft, new_vec_znx_dft, vec_znx_dft_t, vec_znx_idft, vec_znx_idft_tmp_a, - vec_znx_idft_tmp_bytes, -}; -use crate::{Free, Module, VecZnxBig}; +use crate::ffi::vec_znx_big; +use crate::ffi::vec_znx_dft; +use crate::{Module, VecZnxBig}; -pub struct VecZnxDft(pub *mut vec_znx_dft_t, pub usize); +pub struct VecZnxDft(pub *mut vec_znx_dft::vec_znx_dft_t, pub usize); impl VecZnxDft { pub fn as_vec_znx_big(&mut self) -> VecZnxBig { - VecZnxBig(self.0 as *mut vec_znx_bigcoeff_t, self.1) + VecZnxBig(self.0 as *mut vec_znx_big::vec_znx_bigcoeff_t, self.1) } pub fn limbs(&self) -> usize { self.1 } } -impl Free for VecZnxDft { - fn free(self) { - unsafe { delete_vec_znx_dft(self.0) }; - drop(self); - } -} - impl Module { // Allocates a vector Z[X]/(X^N+1) that stores normalized in the DFT space. pub fn new_vec_znx_dft(&self, limbs: usize) -> VecZnxDft { - unsafe { VecZnxDft(new_vec_znx_dft(self.0, limbs as u64), limbs) } + unsafe { VecZnxDft(vec_znx_dft::new_vec_znx_dft(self.0, limbs as u64), limbs) } } // b <- IDFT(a), uses a as scratch space. @@ -37,12 +27,12 @@ impl Module { b.limbs(), a_limbs ); - unsafe { vec_znx_idft_tmp_a(self.0, b.0, a_limbs as u64, a.0, a_limbs as u64) } + unsafe { vec_znx_dft::vec_znx_idft_tmp_a(self.0, b.0, a_limbs as u64, a.0, a_limbs as u64) } } // Returns the size of the scratch space for [vec_znx_idft]. 
pub fn vec_znx_idft_tmp_bytes(&self) -> usize { - unsafe { vec_znx_idft_tmp_bytes(self.0) as usize } + unsafe { vec_znx_dft::vec_znx_idft_tmp_bytes(self.0) as usize } } // b <- IDFT(a), scratch space size obtained with [vec_znx_idft_tmp_bytes]. @@ -72,7 +62,7 @@ impl Module { self.vec_znx_idft_tmp_bytes() ); unsafe { - vec_znx_idft( + vec_znx_dft::vec_znx_idft( self.0, b_vector.0, a_limbs as u64, diff --git a/base2k/src/vmp.rs b/base2k/src/vmp.rs index 1260c0d..a228d54 100644 --- a/base2k/src/vmp.rs +++ b/base2k/src/vmp.rs @@ -1,10 +1,5 @@ -use crate::ffi::vmp::{ - delete_vmp_pmat, new_vmp_pmat, vmp_apply_dft, vmp_apply_dft_tmp_bytes, vmp_apply_dft_to_dft, - vmp_apply_dft_to_dft_tmp_bytes, vmp_pmat_t, vmp_prepare_contiguous, - vmp_prepare_contiguous_tmp_bytes, -}; -use crate::Free; -use crate::{Module, VecZnx, VecZnxDft}; +use crate::ffi::vmp; +use crate::{Infos, Module, VecZnx, VecZnxDft}; use std::cmp::min; /// Vector Matrix Product Prepared Matrix: a vector of [VecZnx], @@ -15,10 +10,10 @@ use std::cmp::min; /// and thus must be manually freed. /// /// [VmpPMat] is used to perform a vector matrix product between a [VecZnx] and a [VmpPMat]. -/// See the trait [VectorMatrixProduct] for additional information. +/// See the trait [VmpPMatOps] for additional information. pub struct VmpPMat { /// The pointer to the C memory. - pub data: *mut vmp_pmat_t, + pub data: *mut vmp::vmp_pmat_t, /// The number of [VecZnxDft]. pub rows: usize, /// The number of limbs in each [VecZnxDft]. pub cols: usize, @@ -29,31 +24,18 @@ impl VmpPMat { /// Returns the pointer to the [vmp_pmat_t]. - pub fn data(&self) -> *mut vmp_pmat_t { + pub fn data(&self) -> *mut vmp::vmp_pmat_t { self.data } - /// Returns the number of rows of the [VmpPMat]. - /// The number of rows (i.e. of [VecZnx]) of the [VmpPMat]. - pub fn rows(&self) -> usize { - self.rows - } - - /// Returns the number of cols of the [VmpPMat]. - /// The number of cols refers to the number of limbs - /// of the prepared [VecZnx]. - pub fn cols(&self) -> usize { - self.cols - } - - /// Returns the ring dimension of the [VmpPMat]. - pub fn n(&self) -> usize { - self.n - } - /// Returns a copy of the backend array at index (i, j) of the [VmpPMat]. /// When using [`crate::FFT64`] as backend, `T` should be [f64]. /// When using [`crate::NTT120`] as backend, `T` should be [i64]. + /// + /// # Arguments + /// + /// * `row`: row index (i). + /// * `col`: col index (j). pub fn at(&self, row: usize, col: usize) -> Vec { let mut res: Vec = vec![T::default(); self.n]; @@ -86,7 +68,7 @@ impl VmpPMat { } } - /// Returns a non-mutable reference of [T] of the entire contiguous array of the [VmpPMat]. + /// Returns a non-mutable reference of `T` of the entire contiguous array of the [VmpPMat]. /// When using [`crate::FFT64`] as backend, `T` should be [f64]. /// When using [`crate::NTT120`] as backend, `T` should be [i64]. /// The length of the returned array is rows * cols * n. @@ -97,32 +79,38 @@ impl VmpPMat { } } -impl Free for VmpPMat { - fn free(self) { - unsafe { delete_vmp_pmat(self.data) }; - drop(self); - } -} - /// This trait implements methods for vector matrix product, /// that is, multiplying a [VecZnx] with a [VmpPMat]. -pub trait VectorMatrixProduct { +pub trait VmpPMatOps { /// Allocates a new [VmpPMat] with the given number of rows and columns. + /// + /// # Arguments + /// + /// * `rows`: number of rows (number of [VecZnxDft]). + /// * `cols`: number of cols (number of limbs of each [VecZnxDft]).
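+ /// + /// The returned [VmpPMat] is allocated through C and must be released with [Free::free].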
diff --git a/base2k/src/vmp.rs b/base2k/src/vmp.rs
index 1260c0d..a228d54 100644
--- a/base2k/src/vmp.rs
+++ b/base2k/src/vmp.rs
@@ -1,10 +1,5 @@
-use crate::ffi::vmp::{
-    delete_vmp_pmat, new_vmp_pmat, vmp_apply_dft, vmp_apply_dft_tmp_bytes, vmp_apply_dft_to_dft,
-    vmp_apply_dft_to_dft_tmp_bytes, vmp_pmat_t, vmp_prepare_contiguous,
-    vmp_prepare_contiguous_tmp_bytes,
-};
-use crate::Free;
-use crate::{Module, VecZnx, VecZnxDft};
+use crate::ffi::vmp;
+use crate::{Infos, Module, VecZnx, VecZnxDft};
 use std::cmp::min;
 
 /// Vector Matrix Product Prepared Matrix: a vector of [VecZnx],
@@ -15,10 +10,10 @@ use std::cmp::min;
 /// and thus must be manually freed.
 ///
 /// [VmpPMat] is used to perform a vector matrix product between a [VecZnx] and a [VmpPMat].
-/// See the trait [VectorMatrixProduct] for additional information.
+/// See the trait [VmpPMatOps] for additional information.
 pub struct VmpPMat {
     /// The pointer to the C memory.
-    pub data: *mut vmp_pmat_t,
+    pub data: *mut vmp::vmp_pmat_t,
     /// The number of [VecZnxDft].
     pub rows: usize,
     /// The number of limbs in each [VecZnxDft].
     pub cols: usize,
@@ -29,31 +24,18 @@ impl VmpPMat {
     /// Returns the pointer to the [vmp_pmat_t].
-    pub fn data(&self) -> *mut vmp_pmat_t {
+    pub fn data(&self) -> *mut vmp::vmp_pmat_t {
         self.data
     }
 
-    /// Returns the number of rows of the [VmpPMat].
-    /// The number of rows (i.e. of [VecZnx]) of the [VmpPMat].
-    pub fn rows(&self) -> usize {
-        self.rows
-    }
-
-    /// Returns the number of cols of the [VmpPMat].
-    /// The number of cols refers to the number of limbs
-    /// of the prepared [VecZnx].
-    pub fn cols(&self) -> usize {
-        self.cols
-    }
-
-    /// Returns the ring dimension of the [VmpPMat].
-    pub fn n(&self) -> usize {
-        self.n
-    }
-
     /// Returns a copy of the backend array at index (i, j) of the [VmpPMat].
     /// When using [`crate::FFT64`] as backend, `T` should be [f64].
     /// When using [`crate::NTT120`] as backend, `T` should be [i64].
+    ///
+    /// # Arguments
+    ///
+    /// * `row`: row index (i).
+    /// * `col`: col index (j).
     pub fn at<T: Default + Copy>(&self, row: usize, col: usize) -> Vec<T> {
         let mut res: Vec<T> = vec![T::default(); self.n];
 
@@ -86,7 +68,7 @@ impl VmpPMat {
         }
     }
 
-    /// Returns a non-mutable reference of [T] of the entire contiguous array of the [VmpPMat].
+    /// Returns a non-mutable reference of `T` to the entire contiguous array of the [VmpPMat].
     /// When using [`crate::FFT64`] as backend, `T` should be [f64].
     /// When using [`crate::NTT120`] as backend, `T` should be [i64].
     /// The length of the returned array is rows * cols * n.
@@ -97,32 +79,38 @@ impl VmpPMat {
     }
 }
 
-impl Free for VmpPMat {
-    fn free(self) {
-        unsafe { delete_vmp_pmat(self.data) };
-        drop(self);
-    }
-}
-
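To make the accessor semantics concrete, a small inspection sketch; it assumes the FFT64 backend (backing type `f64`) and a matrix already prepared via the trait below, and that `at`/`raw_data` are generic as documented:

```
use base2k::VmpPMat;

// Sketch: read back entries of a prepared matrix (FFT64 backend => f64).
fn inspect(pmat: &VmpPMat) {
    // Copy of the backend polynomial stored at (row=0, col=0).
    let first: Vec<f64> = pmat.at(0, 0);
    println!("first backend coefficient: {}", first[0]);
    // The whole prepared matrix as one contiguous slice of rows * cols * n values.
    let flat: &[f64] = pmat.raw_data();
    println!("total backend values: {}", flat.len());
}
```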
 /// This trait implements methods for vector matrix product,
 /// that is, multiplying a [VecZnx] with a [VmpPMat].
-pub trait VectorMatrixProduct {
+pub trait VmpPMatOps {
     /// Allocates a new [VmpPMat] with the given number of rows and columns.
+    ///
+    /// # Arguments
+    ///
+    /// * `rows`: number of rows (number of [VecZnxDft]).
+    /// * `cols`: number of cols (number of limbs of each [VecZnxDft]).
     fn new_vmp_pmat(&self, rows: usize, cols: usize) -> VmpPMat;
 
-    /// Returns the number of bytes needed as scratch space for [VectorMatrixProduct::vmp_prepare_contiguous].
+    /// Returns the number of bytes needed as scratch space for [VmpPMatOps::vmp_prepare_contiguous].
+    ///
+    /// # Arguments
+    ///
+    /// * `rows`: number of rows of the [VmpPMat] used in [VmpPMatOps::vmp_prepare_contiguous].
+    /// * `cols`: number of cols of the [VmpPMat] used in [VmpPMatOps::vmp_prepare_contiguous].
     fn vmp_prepare_contiguous_tmp_bytes(&self, rows: usize, cols: usize) -> usize;
 
     /// Prepares a [VmpPMat] from a contiguous array of [i64].
     /// The helper struct [Matrix3D] can be used to construct and populate
     /// the appropriate contiguous array.
     ///
-    /// The size of buf can be obtained with [VectorMatrixProduct::vmp_prepare_contiguous_tmp_bytes].
+    /// # Arguments
+    ///
+    /// * `b`: [VmpPMat] on which the values are encoded.
+    /// * `a`: the contiguous array of [i64] of the 3D matrix to encode on the [VmpPMat].
+    /// * `buf`: scratch space, the size of buf can be obtained with [VmpPMatOps::vmp_prepare_contiguous_tmp_bytes].
     ///
     /// # Example
     /// ```
-    /// use base2k::{Module, Matrix3D, VmpPMat, FFT64, Free};
-    /// use base2k::vmp::VectorMatrixProduct;
+    /// use base2k::{Module, Matrix3D, VmpPMat, VmpPMatOps, FFT64, Free};
     /// use std::cmp::min;
     ///
     /// let n: usize = 1024;
@@ -148,12 +136,17 @@ pub trait VectorMatrixProduct {
 
     /// Prepares a [VmpPMat] from a vector of [VecZnx].
     ///
-    /// The size of buf can be obtained with [VectorMatrixProduct::vmp_prepare_contiguous_tmp_bytes].
+    /// # Arguments
+    ///
+    /// * `b`: [VmpPMat] on which the values are encoded.
+    /// * `a`: the vector of [VecZnx] to encode on the [VmpPMat].
+    /// * `buf`: scratch space, the size of buf can be obtained with [VmpPMatOps::vmp_prepare_contiguous_tmp_bytes].
     ///
     /// # Example
     /// ```
-    /// use base2k::{Module, FFT64, Matrix3D, VmpPMat, VecZnx, Free};
-    /// use base2k::vmp::VectorMatrixProduct;
+    /// use base2k::{Module, FFT64, Matrix3D, VmpPMat, VmpPMatOps, VecZnx, VecZnxOps, Free};
     /// use std::cmp::min;
     ///
     /// let n: usize = 1024;
@@ -176,7 +169,14 @@
     /// ```
     fn vmp_prepare_dblptr(&self, b: &mut VmpPMat, a: &Vec<VecZnx>, buf: &mut [u8]);
 
-    /// Returns the size of the stratch space necessary for [VectorMatrixProduct::vmp_apply_dft].
+    /// Returns the size of the scratch space necessary for [VmpPMatOps::vmp_apply_dft].
+    ///
+    /// # Arguments
+    ///
+    /// * `c_limbs`: number of limbs of the output [VecZnxDft].
+    /// * `a_limbs`: number of limbs of the input [VecZnx].
+    /// * `rows`: number of rows of the input [VmpPMat].
+    /// * `cols`: number of cols of the input [VmpPMat].
     fn vmp_apply_dft_tmp_bytes(
         &self,
         c_limbs: usize,
@@ -186,9 +186,8 @@
     ) -> usize;
 
     /// Applies the vector matrix product [VecZnxDft] x [VmpPMat].
-    /// The size of `buf` is given by [VectorMatrixProduct::vmp_apply_dft_to_dft_tmp_bytes].
     ///
-    /// A vector matrix product is equivalent to a sum of [ScalarVectorProduct::svp_apply_dft]
+    /// A vector matrix product is equivalent to a sum of [crate::SvpPPolOps::svp_apply_dft]
     /// where each [crate::Scalar] is a limb of the input [VecZnxDft] (equivalent to an [crate::SvpPPol])
     /// and each vector a [VecZnxDft] (row) of the [VmpPMat].
     ///
@@ -204,10 +203,16 @@
     /// ```
     /// where each element is a [VecZnxDft].
     ///
+    /// # Arguments
+    ///
+    /// * `c`: the output of the vector matrix product, as a [VecZnxDft].
+    /// * `a`: the left operand [VecZnx] of the vector matrix product.
+    /// * `b`: the right operand [VmpPMat] of the vector matrix product.
+    /// * `buf`: scratch space, the size can be obtained with [VmpPMatOps::vmp_apply_dft_tmp_bytes].
+    ///
     /// # Example
     /// ```
-    /// use base2k::{Module, VecZnx, VecZnxDft, VmpPMat, FFT64, Free};
-    /// use base2k::vmp::VectorMatrixProduct;
+    /// use base2k::{Module, VecZnx, VecZnxOps, VecZnxDft, VmpPMat, VmpPMatOps, FFT64, Free};
     ///
     /// let n = 1024;
     ///
@@ -233,7 +238,14 @@ pub trait VectorMatrixProduct {
     /// ```
     fn vmp_apply_dft(&self, c: &mut VecZnxDft, a: &VecZnx, b: &VmpPMat, buf: &mut [u8]);
 
-    /// Returns the size of the stratch space necessary for [VectorMatrixProduct::vmp_apply_dft_to_dft].
+    /// Returns the size of the scratch space necessary for [VmpPMatOps::vmp_apply_dft_to_dft].
+    ///
+    /// # Arguments
+    ///
+    /// * `c_limbs`: number of limbs of the output [VecZnxDft].
+    /// * `a_limbs`: number of limbs of the input [VecZnxDft].
+    /// * `rows`: number of rows of the input [VmpPMat].
+    /// * `cols`: number of cols of the input [VmpPMat].
     fn vmp_apply_dft_to_dft_tmp_bytes(
         &self,
         c_limbs: usize,
@@ -243,9 +255,9 @@ pub trait VectorMatrixProduct {
     ) -> usize;
 
     /// Applies the vector matrix product [VecZnxDft] x [VmpPMat].
-    /// The size of `buf` is given by [VectorMatrixProduct::vmp_apply_dft_to_dft_tmp_bytes].
+    /// The size of `buf` is given by [VmpPMatOps::vmp_apply_dft_to_dft_tmp_bytes].
     ///
-    /// A vector matrix product is equivalent to a sum of [ScalarVectorProduct::svp_apply_dft]
+    /// A vector matrix product is equivalent to a sum of [crate::SvpPPolOps::svp_apply_dft]
     /// where each [crate::Scalar] is a limb of the input [VecZnxDft] (equivalent to an [crate::SvpPPol])
     /// and each vector a [VecZnxDft] (row) of the [VmpPMat].
     ///
@@ -261,10 +273,16 @@
     /// ```
     /// where each element is a [VecZnxDft].
     ///
+    /// # Arguments
+    ///
+    /// * `c`: the output of the vector matrix product, as a [VecZnxDft].
+    /// * `a`: the left operand [VecZnxDft] of the vector matrix product.
+    /// * `b`: the right operand [VmpPMat] of the vector matrix product.
+    /// * `buf`: scratch space, the size can be obtained with [VmpPMatOps::vmp_apply_dft_to_dft_tmp_bytes].
+    ///
     /// # Example
     /// ```
-    /// use base2k::{Module, VecZnx, VecZnxDft, VmpPMat, FFT64, Free};
-    /// use base2k::vmp::VectorMatrixProduct;
+    /// use base2k::{Module, VecZnx, VecZnxDft, VmpPMat, VmpPMatOps, FFT64, Free};
     ///
     /// let n = 1024;
     ///
@@ -275,7 +293,7 @@ pub trait VectorMatrixProduct {
     /// let cols: usize = limbs + 1;
     /// let c_limbs: usize = cols;
     /// let a_limbs: usize = limbs;
-    /// let tmp_bytes: usize = module.vmp_apply_dft_tmp_bytes(c_limbs, a_limbs, rows, cols);
+    /// let tmp_bytes: usize = module.vmp_apply_dft_to_dft_tmp_bytes(c_limbs, a_limbs, rows, cols);
     ///
     /// let mut buf: Vec<u8> = vec![0; tmp_bytes];
     /// let mut vmp_pmat: VmpPMat = module.new_vmp_pmat(rows, cols);
@@ -292,9 +310,9 @@ pub trait VectorMatrixProduct {
     fn vmp_apply_dft_to_dft(&self, c: &mut VecZnxDft, a: &VecZnxDft, b: &VmpPMat, buf: &mut [u8]);
 
     /// Applies the vector matrix product [VecZnxDft] x [VmpPMat] in place.
-    /// The size of `buf` is given by [VectorMatrixProduct::vmp_apply_dft_to_dft_tmp_bytes].
+    /// The size of `buf` is given by [VmpPMatOps::vmp_apply_dft_to_dft_tmp_bytes].
     ///
-    /// A vector matrix product is equivalent to a sum of [ScalarVectorProduct::svp_apply_dft]
+    /// A vector matrix product is equivalent to a sum of [crate::SvpPPolOps::svp_apply_dft]
     /// where each [crate::Scalar] is a limb of the input [VecZnxDft] (equivalent to an [crate::SvpPPol])
     /// and each vector a [VecZnxDft] (row) of the [VmpPMat].
     ///
@@ -310,10 +328,15 @@
     /// ```
     /// where each element is a [VecZnxDft].
     ///
+    /// # Arguments
+    ///
+    /// * `b`: the input and output of the vector matrix product, as a [VecZnxDft].
+    /// * `a`: the right operand [VmpPMat] of the vector matrix product.
+    /// * `buf`: scratch space, the size can be obtained with [VmpPMatOps::vmp_apply_dft_to_dft_tmp_bytes].
+    ///
     /// # Example
     /// ```
-    /// use base2k::{Module, VecZnx, VecZnxDft, VmpPMat, FFT64, Free};
-    /// use base2k::vmp::VectorMatrixProduct;
+    /// use base2k::{Module, VecZnx, VecZnxOps, VecZnxDft, VmpPMat, VmpPMatOps, FFT64, Free};
     ///
     /// let n = 1024;
     ///
@@ -322,7 +345,7 @@ pub trait VectorMatrixProduct {
     ///
     /// let rows: usize = limbs;
     /// let cols: usize = limbs + 1;
-    /// let tmp_bytes: usize = module.vmp_apply_dft_tmp_bytes(limbs, limbs, rows, cols);
+    /// let tmp_bytes: usize = module.vmp_apply_dft_to_dft_tmp_bytes(limbs, limbs, rows, cols);
     ///
     /// let mut buf: Vec<u8> = vec![0; tmp_bytes];
     /// let a: VecZnx = module.new_vec_znx(limbs);
@@ -338,24 +361,25 @@ pub trait VectorMatrixProduct {
     fn vmp_apply_dft_to_dft_inplace(&self, b: &mut VecZnxDft, a: &VmpPMat, buf: &mut [u8]);
 }
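The examples above all follow the same dimension bookkeeping: `rows` tracks the limbs of the base-2^k decomposed input and `cols` the limbs of the output. A tiny self-contained helper capturing that reading of the docs (this is an illustration, not an API of the patch):

```
// Sketch: derive VmpPMat dimensions from the input/output limb counts,
// mirroring `rows = a_limbs`, `cols = c_limbs` used in the doc examples above.
fn pmat_dims(a_limbs: usize, c_limbs: usize) -> (usize, usize) {
    (a_limbs, c_limbs)
}

fn main() {
    let limbs = 4;
    let (rows, cols) = pmat_dims(limbs, limbs + 1);
    assert_eq!((rows, cols), (4, 5));
}
```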
-impl VectorMatrixProduct for Module {
+impl VmpPMatOps for Module {
     fn new_vmp_pmat(&self, rows: usize, cols: usize) -> VmpPMat {
         unsafe {
             VmpPMat {
-                data: new_vmp_pmat(self.0, rows as u64, cols as u64),
+                data: vmp::new_vmp_pmat(self.0, rows as u64, cols as u64),
                 rows,
                 cols,
                 n: self.n(),
             }
         }
     }
+
     fn vmp_prepare_contiguous_tmp_bytes(&self, rows: usize, cols: usize) -> usize {
-        unsafe { vmp_prepare_contiguous_tmp_bytes(self.0, rows as u64, cols as u64) as usize }
+        unsafe { vmp::vmp_prepare_contiguous_tmp_bytes(self.0, rows as u64, cols as u64) as usize }
     }
 
     fn vmp_prepare_contiguous(&self, b: &mut VmpPMat, a: &[i64], buf: &mut [u8]) {
         unsafe {
-            vmp_prepare_contiguous(
+            vmp::vmp_prepare_contiguous(
                 self.0,
                 b.data(),
                 a.as_ptr(),
@@ -402,7 +426,7 @@ impl VectorMatrixProduct for Module {
         cols: usize,
     ) -> usize {
         unsafe {
-            vmp_apply_dft_tmp_bytes(
+            vmp::vmp_apply_dft_tmp_bytes(
                 self.0,
                 c_limbs as u64,
                 a_limbs as u64,
@@ -414,7 +438,7 @@ impl VectorMatrixProduct for Module {
 
     fn vmp_apply_dft(&self, c: &mut VecZnxDft, a: &VecZnx, b: &VmpPMat, buf: &mut [u8]) {
         unsafe {
-            vmp_apply_dft(
+            vmp::vmp_apply_dft(
                 self.0,
                 c.0,
                 c.limbs() as u64,
@@ -437,7 +461,7 @@ impl VectorMatrixProduct for Module {
         cols: usize,
     ) -> usize {
         unsafe {
-            vmp_apply_dft_to_dft_tmp_bytes(
+            vmp::vmp_apply_dft_to_dft_tmp_bytes(
                 self.0,
                 c_limbs as u64,
                 a_limbs as u64,
@@ -449,7 +473,7 @@ impl VectorMatrixProduct for Module {
 
     fn vmp_apply_dft_to_dft(&self, c: &mut VecZnxDft, a: &VecZnxDft, b: &VmpPMat, buf: &mut [u8]) {
         unsafe {
-            vmp_apply_dft_to_dft(
+            vmp::vmp_apply_dft_to_dft(
                 self.0,
                 c.0,
                 c.limbs() as u64,
@@ -465,7 +489,7 @@ impl VectorMatrixProduct for Module {
 
     fn vmp_apply_dft_to_dft_inplace(&self, b: &mut VecZnxDft, a: &VmpPMat, buf: &mut [u8]) {
         unsafe {
-            vmp_apply_dft_to_dft(
+            vmp::vmp_apply_dft_to_dft(
                 self.0,
                 b.0,
                 b.limbs() as u64,
@@ -481,7 +505,7 @@ impl VectorMatrixProduct for Module {
     }
 }
 
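Since every method above threads an external `buf`, a caller can allocate one buffer sized for the largest consumer; a minimal sketch using only the `*_tmp_bytes` queries implemented above:

```
use base2k::{Module, VmpPMatOps};

// Sketch: one scratch buffer large enough for both prepare and apply.
// All sizes come from the module; dimension arguments are the caller's.
fn shared_scratch(module: &Module, rows: usize, cols: usize, c_limbs: usize, a_limbs: usize) -> Vec<u8> {
    let prep = module.vmp_prepare_contiguous_tmp_bytes(rows, cols);
    let apply = module.vmp_apply_dft_tmp_bytes(c_limbs, a_limbs, rows, cols);
    vec![0u8; prep.max(apply)]
}
```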
 /// A helper structure that stores a 3D matrix as a contiguous array.
-/// To be passed to [VectorMatrixProduct::vmp_prepare_contiguous].
+/// To be passed to [VmpPMatOps::vmp_prepare_contiguous].
 ///
 /// rows: index of the i-th base2K power.
 /// cols: index of the j-th limb of the i-th row.
@@ -498,6 +522,12 @@ pub struct Matrix3D {
 impl Matrix3D {
     /// Allocates a new [Matrix3D] with the respective dimensions.
     ///
+    /// # Arguments
+    ///
+    /// * `rows`: the number of rows of the matrix.
+    /// * `cols`: the number of cols of the matrix.
+    /// * `n`: the size of each entry of the matrix.
+    ///
     /// # Example
     /// ```
     /// use base2k::Matrix3D;
@@ -521,6 +551,11 @@ impl Matrix3D {
     /// Returns a non-mutable reference to the entry (row, col) of the [Matrix3D].
     /// The returned array is of size n.
     ///
+    /// # Arguments
+    ///
+    /// * `row`: the index of the row.
+    /// * `col`: the index of the col.
+    ///
     /// # Example
     /// ```
     /// use base2k::Matrix3D;
@@ -542,6 +577,11 @@ impl Matrix3D {
     /// Returns a mutable reference of the array at the (row, col) entry of the [Matrix3D].
     /// The returned array is of size n.
     ///
+    /// # Arguments
+    ///
+    /// * `row`: the index of the row.
+    /// * `col`: the index of the col.
+    ///
     /// # Example
     /// ```
     /// use base2k::Matrix3D;
@@ -564,6 +604,11 @@ impl Matrix3D {
     /// Typically this is used to assign a [VecZnx] to the i-th row
     /// of the [Matrix3D].
     ///
+    /// # Arguments
+    ///
+    /// * `row`: the index of the row.
+    /// * `a`: the data to encode on the row.
+    ///
     /// # Example
     /// ```
     /// use base2k::{Matrix3D, VecZnx};