diff --git a/base2k/examples/rlwe_encrypt.rs b/base2k/examples/rlwe_encrypt.rs
index be432fb..f2ffbe4 100644
--- a/base2k/examples/rlwe_encrypt.rs
+++ b/base2k/examples/rlwe_encrypt.rs
@@ -1,6 +1,6 @@
 use base2k::{
-    Encoding, Infos, Module, Sampling, Scalar, SvpPPol, SvpPPolOps, VecZnx, VecZnxBig, VecZnxDft,
-    VecZnxOps, FFT64,
+    Encoding, Infos, Module, Sampling, Scalar, SvpPPol, SvpPPolOps, VecZnx, VecZnxApi, VecZnxBig,
+    VecZnxBigOps, VecZnxDft, VecZnxDftOps, VecZnxOps, FFT64,
 };
 use itertools::izip;
 use sampling::source::Source;
diff --git a/base2k/examples/vector_matrix_product.rs b/base2k/examples/vector_matrix_product.rs
index e4e1e29..7dd9a59 100644
--- a/base2k/examples/vector_matrix_product.rs
+++ b/base2k/examples/vector_matrix_product.rs
@@ -1,6 +1,6 @@
 use base2k::{
-    Encoding, Free, Infos, Module, VecZnx, VecZnxBig, VecZnxDft, VecZnxOps, VmpPMat, VmpPMatOps,
-    FFT64,
+    Encoding, Free, Infos, Module, VecZnx, VecZnxApi, VecZnxBig, VecZnxBigOps, VecZnxDft,
+    VecZnxDftOps, VecZnxOps, VmpPMat, VmpPMatOps, FFT64,
 };
 
 fn main() {
@@ -40,10 +40,8 @@ fn main() {
         vecznx[i].data[i * n + 1] = 1 as i64;
     });
 
-    let dble: Vec<&[i64]> = vecznx.iter().map(|v| v.data.as_slice()).collect();
-
     let mut vmp_pmat: VmpPMat = module.new_vmp_pmat(rows, cols);
-    module.vmp_prepare_dblptr(&mut vmp_pmat, &dble, &mut buf);
+    module.vmp_prepare_dblptr(&mut vmp_pmat, &vecznx, &mut buf);
 
     let mut c_dft: VecZnxDft = module.new_vec_znx_dft(cols);
     module.vmp_apply_dft(&mut c_dft, &a, &vmp_pmat, &mut buf);
@@ -65,47 +63,3 @@ fn main() {
 
     //println!("{:?}", values_res)
 }
-
-/*
-
-use base2k::{
-    Encoding, Free, Infos, Matrix3D, Module, VecZnx, VecZnxBig, VecZnxDft, VecZnxOps, VmpPMat,
-    VmpPMatOps, FFT64,
-};
-use std::cmp::min;
-
-fn main() {
-    use base2k::{Module, FFT64, Matrix3D, VmpPMat, VmpPMatOps, VecZnx, VecZnxOps, Free};
-    use std::cmp::min;
-
-    let n: usize = 32;
-    let module: Module = Module::new::<FFT64>(n);
-    let rows: usize = 5;
-    let cols: usize = 6;
-
-    let mut vecznx: Vec<VecZnx> = Vec::new();
-    (0..rows).for_each(|_|{
-        vecznx.push(module.new_vec_znx(cols));
-    });
-
-    (0..rows).for_each(|i|{
-        vecznx[i].data[i*n] = 1 as i64;
-        vecznx[i].print_limbs(cols, n);
-    });
-
-    let mut buf: Vec<u8> = vec![u8::default(); module.vmp_prepare_tmp_bytes(rows, cols)];
-
-    let mut vmp_pmat: VmpPMat = module.new_vmp_pmat(rows, cols);
-
-    println!("123");
-
-    module.vmp_prepare_dblptr(&mut vmp_pmat, &vecznx, &mut buf);
-
-
-    module.vmp_apply_dft(c, a, b, buf);
-
-    vmp_pmat.free();
-    module.free();
-}
-
-*/
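Reviewer note on the hunk above: the `Vec<&[i64]>` detour existed only because `vmp_prepare_dblptr` used to take a double-pointer view of the rows. With the method now generic over `VecZnxApi`, the call site shrinks to a sketch like this (names as in the example above):

```rust
// Before: build an explicit slice-of-slices view of the rows.
// let dble: Vec<&[i64]> = vecznx.iter().map(|v| v.data.as_slice()).collect();
// module.vmp_prepare_dblptr(&mut vmp_pmat, &dble, &mut buf);

// After: the Vec<VecZnx> itself satisfies the generic bound.
module.vmp_prepare_dblptr(&mut vmp_pmat, &vecznx, &mut buf);
```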
diff --git a/base2k/src/encoding.rs b/base2k/src/encoding.rs
index 1808fd4..fcabc49 100644
--- a/base2k/src/encoding.rs
+++ b/base2k/src/encoding.rs
@@ -1,5 +1,5 @@
 use crate::ffi::znx::znx_zero_i64_ref;
-use crate::{Infos, VecZnx};
+use crate::{Infos, VecZnx, VecZnxApi};
 use itertools::izip;
 use std::cmp::min;
diff --git a/base2k/src/infos.rs b/base2k/src/infos.rs
index 554b747..2d0441f 100644
--- a/base2k/src/infos.rs
+++ b/base2k/src/infos.rs
@@ -1,4 +1,4 @@
-use crate::{VecZnx, VmpPMat};
+use crate::{VecZnx, VecZnxBorrow, VmpPMat};
 
 pub trait Infos {
     /// Returns the ring degree of the receiver.
@@ -46,6 +46,33 @@ impl Infos for VecZnx {
     }
 }
 
+impl Infos for VecZnxBorrow {
+    /// Returns the base 2 logarithm of the [VecZnxBorrow] degree.
+    fn log_n(&self) -> usize {
+        (usize::BITS - (self.n - 1).leading_zeros()) as _
+    }
+
+    /// Returns the [VecZnxBorrow] degree.
+    fn n(&self) -> usize {
+        self.n
+    }
+
+    /// Returns the number of limbs of the [VecZnxBorrow].
+    fn limbs(&self) -> usize {
+        self.limbs
+    }
+
+    /// Returns the number of columns of the [VecZnxBorrow], which equals its number of limbs.
+    fn cols(&self) -> usize {
+        self.limbs
+    }
+
+    /// Returns the number of rows of the [VecZnxBorrow], which is always 1.
+    fn rows(&self) -> usize {
+        1
+    }
+}
+
 impl Infos for VmpPMat {
     /// Returns the ring dimension of the [VmpPMat].
     fn n(&self) -> usize {
diff --git a/base2k/src/lib.rs b/base2k/src/lib.rs
index 769858b..7db95ce 100644
--- a/base2k/src/lib.rs
+++ b/base2k/src/lib.rs
@@ -1,3 +1,4 @@
+pub mod encoding;
 #[allow(
     non_camel_case_types,
     non_snake_case,
@@ -5,75 +6,47 @@
     dead_code,
     improper_ctypes
 )]
+// Other modules and exports
 pub mod ffi;
-
-pub mod module;
-#[allow(unused_imports)]
-pub use module::*;
-
-pub mod vec_znx;
-#[allow(unused_imports)]
-pub use vec_znx::*;
-
-pub mod vec_znx_big;
-#[allow(unused_imports)]
-pub use vec_znx_big::*;
-
-pub mod vec_znx_dft;
-#[allow(unused_imports)]
-pub use vec_znx_dft::*;
-
-pub mod svp;
-#[allow(unused_imports)]
-pub use svp::*;
-
-pub mod vmp;
-#[allow(unused_imports)]
-pub use vmp::*;
-
-pub mod sampling;
-#[allow(unused_imports)]
-pub use sampling::*;
-
-pub mod encoding;
-#[allow(unused_imports)]
-pub use encoding::*;
-
-pub mod infos;
-#[allow(unused_imports)]
-pub use infos::*;
-
 pub mod free;
-#[allow(unused_imports)]
+pub mod infos;
+pub mod module;
+pub mod sampling;
+pub mod svp;
+pub mod vec_znx;
+pub mod vec_znx_big;
+pub mod vec_znx_dft;
+pub mod vmp;
+
+pub use encoding::*;
 pub use free::*;
+pub use infos::*;
+pub use module::*;
+pub use sampling::*;
+pub use svp::*;
+pub use vec_znx::*;
+pub use vec_znx_big::*;
+pub use vec_znx_dft::*;
+pub use vmp::*;
 
 pub const GALOISGENERATOR: u64 = 5;
 
-#[allow(dead_code)]
-pub fn cast_mut_u64_to_mut_u8_slice(data: &mut [u64]) -> &mut [u8] {
-    let ptr: *mut u8 = data.as_mut_ptr() as *mut u8;
-    let len: usize = data.len() * std::mem::size_of::<u64>();
-    unsafe { std::slice::from_raw_parts_mut(ptr, len) }
+fn is_aligned<T>(ptr: *const T, align: usize) -> bool {
+    (ptr as usize) % align == 0
 }
 
-pub fn cast_mut_u8_to_mut_i64_slice(data: &mut [u8]) -> &mut [i64] {
-    let ptr: *mut i64 = data.as_mut_ptr() as *mut i64;
-    let len: usize = data.len() / std::mem::size_of::<i64>();
-    unsafe { std::slice::from_raw_parts_mut(ptr, len) }
-}
-
-pub fn cast_mut_u8_to_mut_f64_slice(data: &mut [u8]) -> &mut [f64] {
-    let ptr: *mut f64 = data.as_mut_ptr() as *mut f64;
-    let len: usize = data.len() / std::mem::size_of::<f64>();
-    unsafe { std::slice::from_raw_parts_mut(ptr, len) }
-}
-
-pub fn cast_u8_to_f64_slice(data: &mut [u8]) -> &[f64] {
-    let ptr: *const f64 = data.as_mut_ptr() as *const f64;
-    let len: usize = data.len() / std::mem::size_of::<f64>();
+pub fn cast<T, V>(data: &[T]) -> &[V] {
+    let ptr: *const V = data.as_ptr() as *const V;
+    let len: usize = (data.len() * std::mem::size_of::<T>()) / std::mem::size_of::<V>();
     unsafe { std::slice::from_raw_parts(ptr, len) }
 }
 
+pub fn cast_mut<T, V>(data: &[T]) -> &mut [V] {
+    let ptr: *mut V = data.as_ptr() as *mut V;
+    let len: usize = (data.len() * std::mem::size_of::<T>()) / std::mem::size_of::<V>();
+    unsafe { std::slice::from_raw_parts_mut(ptr, len) }
+}
+
 use std::alloc::{alloc, Layout};
 
 pub fn alloc_aligned_u8(size: usize, align: usize) -> Vec<u8> {
@@ -116,8 +89,10 @@ pub fn alloc_aligned<T>(size: usize, align: usize) -> Vec<T> {
     unsafe { Vec::from_raw_parts(ptr, len, cap) }
 }
 
-fn alias_mut_slice_to_vec<T>(slice: &mut [T]) -> Vec<T> {
-    let ptr = slice.as_mut_ptr();
-    let len = slice.len();
-    unsafe { Vec::from_raw_parts(ptr, len, len) }
+fn alias_mut_slice_to_vec<T>(slice: &[T]) -> Vec<T> {
+    unsafe {
+        let ptr: *mut T = slice.as_ptr() as *mut T;
+        let len: usize = slice.len();
+        Vec::from_raw_parts(ptr, len, len)
+    }
 }
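A minimal usage sketch for the new `cast`/`cast_mut` helpers (illustrative, not part of the diff): the soundness burden is on the caller, who must guarantee alignment (cf. `is_aligned` above) and, for `cast_mut`, exclusive access to the underlying buffer.

```rust
// Reinterpret an aligned byte buffer as i64 coefficients (hypothetical use).
let buf: Vec<u8> = alloc_aligned_u8(4 * 8, 64); // 64-byte aligned, 32 bytes
assert!(is_aligned(buf.as_ptr(), std::mem::align_of::<i64>()));
let coeffs: &mut [i64] = cast_mut(&buf[..]); // 32 bytes -> 4 i64 values
coeffs.fill(0);
```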
diff --git a/base2k/src/module.rs b/base2k/src/module.rs
index 6d1ce47..6a6e528 100644
--- a/base2k/src/module.rs
+++ b/base2k/src/module.rs
@@ -61,3 +61,26 @@ impl Free for Module {
         drop(self);
     }
 }
+
+pub trait Bytes {
+    /// Returns the minimum number of bytes necessary to allocate
+    /// a new [VecZnxBig] through [VecZnxBig::from_bytes].
+    fn bytes_of_vec_znx_big(&self, limbs: usize) -> usize;
+
+    /// Returns the size of the scratch space required by [Module::vec_znx_idft].
+    fn vec_znx_idft_tmp_bytes(&self) -> usize;
+
+    /// Returns the minimum number of bytes necessary to allocate
+    /// a new [VecZnxDft] through [VecZnxDft::from_bytes].
+    ///
+    /// # Arguments
+    ///
+    /// * `limbs`: the number of limbs of the [VecZnxDft].
+    fn bytes_of_vec_znx_dft(&self, limbs: usize) -> usize;
+
+    /// Returns the size of the scratch space required by [Module::vec_znx_big_normalize].
+    fn vec_znx_big_normalize_tmp_bytes(&self) -> usize;
+}
diff --git a/base2k/src/sampling.rs b/base2k/src/sampling.rs
index 14e5dd6..bfe1ec3 100644
--- a/base2k/src/sampling.rs
+++ b/base2k/src/sampling.rs
@@ -1,4 +1,4 @@
-use crate::{Infos, VecZnx};
+use crate::{Infos, VecZnx, VecZnxApi};
 use rand_distr::{Distribution, Normal};
 use sampling::source::Source;
diff --git a/base2k/src/svp.rs b/base2k/src/svp.rs
index ee5cc3b..eab573d 100644
--- a/base2k/src/svp.rs
+++ b/base2k/src/svp.rs
@@ -1,7 +1,7 @@
 use crate::ffi::svp;
-use crate::{alias_mut_slice_to_vec, Module, VecZnx, VecZnxDft};
+use crate::{alias_mut_slice_to_vec, Module, VecZnxApi, VecZnxDft};
 
-use crate::{alloc_aligned, cast_mut_u8_to_mut_i64_slice, Infos};
+use crate::{alloc_aligned, cast, Infos};
 use rand::seq::SliceRandom;
 use rand_core::RngCore;
 use rand_distr::{Distribution, WeightedIndex};
@@ -37,7 +37,7 @@ impl Scalar {
             n,
             size
         );
-        self.0 = alias_mut_slice_to_vec(cast_mut_u8_to_mut_i64_slice(&mut buf[..size]))
+        self.0 = alias_mut_slice_to_vec(cast::<u8, i64>(&buf[..size]))
     }
 
     pub fn as_ptr(&self) -> *const i64 {
@@ -96,7 +96,13 @@ pub trait SvpPPolOps {
 
     /// Applies the [SvpPPol] x [VecZnxDft] product, where each limb of
     /// the [VecZnxDft] is multiplied with [SvpPPol].
-    fn svp_apply_dft(&self, c: &mut VecZnxDft, a: &SvpPPol, b: &VecZnx, b_limbs: usize);
+    fn svp_apply_dft<T: VecZnxApi + Infos>(
+        &self,
+        c: &mut VecZnxDft,
+        a: &SvpPPol,
+        b: &T,
+        b_limbs: usize,
+    );
 }
 
 impl SvpPPolOps for Module {
@@ -112,7 +118,13 @@ impl SvpPPolOps for Module {
         unsafe { svp::svp_prepare(self.0, svp_ppol.0, a.as_ptr()) }
     }
 
-    fn svp_apply_dft(&self, c: &mut VecZnxDft, a: &SvpPPol, b: &VecZnx, b_limbs: usize) {
+    fn svp_apply_dft<T: VecZnxApi + Infos>(
+        &self,
+        c: &mut VecZnxDft,
+        a: &SvpPPol,
+        b: &T,
+        b_limbs: usize,
+    ) {
         assert!(
             c.limbs() >= b_limbs,
             "invalid c_vector: c_vector.limbs()={} < b.limbs()={}",
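What the generic `svp_apply_dft` signature buys, as a sketch (the helper below is hypothetical and assumes an already-prepared `SvpPPol`): the same call now accepts an owned `VecZnx` or a `VecZnxBorrow` over scratch memory.

```rust
// Hypothetical helper: works for any T implementing VecZnxApi + Infos.
fn apply_prepared<T: VecZnxApi + Infos>(module: &Module, ppol: &SvpPPol, b: &T, limbs: usize) {
    let mut c: VecZnxDft = module.new_vec_znx_dft(limbs);
    module.svp_apply_dft(&mut c, ppol, b, limbs);
}
```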
diff --git a/base2k/src/vec_znx.rs b/base2k/src/vec_znx.rs
index 615cb35..fecd991 100644
--- a/base2k/src/vec_znx.rs
+++ b/base2k/src/vec_znx.rs
@@ -1,11 +1,210 @@
-use crate::cast_mut_u8_to_mut_i64_slice;
+use crate::cast_mut;
 use crate::ffi::vec_znx;
 use crate::ffi::znx;
+use crate::ffi::znx::znx_zero_i64_ref;
 use crate::{alias_mut_slice_to_vec, alloc_aligned};
 use crate::{Infos, Module};
 use itertools::izip;
 use std::cmp::min;
 
+pub trait VecZnxApi {
+    /// Returns the minimum size of the [u8] array required to assign a
+    /// new backing array to a [VecZnx] through [VecZnx::from_bytes].
+    fn bytes_of(n: usize, limbs: usize) -> usize;
+
+    /// Returns a new struct implementing [VecZnxApi] with the provided data as backing array.
+    ///
+    /// The struct will take ownership of `bytes[..VecZnxApi::bytes_of(n, limbs)]`.
+    ///
+    /// The user must ensure that `bytes` is properly aligned and that
+    /// its size is at least equal to [VecZnxApi::bytes_of].
+    fn from_bytes(n: usize, limbs: usize, bytes: &mut [u8]) -> impl VecZnxApi;
+    fn as_ptr(&self) -> *const i64;
+    fn as_mut_ptr(&mut self) -> *mut i64;
+    fn at(&self, i: usize) -> &[i64];
+    fn at_mut(&mut self, i: usize) -> &mut [i64];
+    fn at_ptr(&self, i: usize) -> *const i64;
+    fn at_mut_ptr(&mut self, i: usize) -> *mut i64;
+    fn zero(&mut self);
+    fn normalize(&mut self, log_base2k: usize, carry: &mut [u8]);
+}
+
+pub fn bytes_of_vec_znx(n: usize, limbs: usize) -> usize {
+    n * limbs * 8
+}
+
+pub struct VecZnxBorrow {
+    pub n: usize,
+    pub limbs: usize,
+    pub data: *mut i64,
+}
+
+impl VecZnxApi for VecZnxBorrow {
+    fn bytes_of(n: usize, limbs: usize) -> usize {
+        bytes_of_vec_znx(n, limbs)
+    }
+
+    fn from_bytes(n: usize, limbs: usize, bytes: &mut [u8]) -> impl VecZnxApi {
+        let size = Self::bytes_of(n, limbs);
+        assert!(
+            bytes.len() >= size,
+            "invalid buffer: buf.len()={} < self.buffer_size(n={}, limbs={})={}",
+            bytes.len(),
+            n,
+            limbs,
+            size
+        );
+        VecZnxBorrow {
+            n,
+            limbs,
+            data: cast_mut(&mut bytes[..size]).as_mut_ptr(),
+        }
+    }
+
+    fn as_ptr(&self) -> *const i64 {
+        self.data
+    }
+
+    fn as_mut_ptr(&mut self) -> *mut i64 {
+        self.data
+    }
+
+    fn at(&self, i: usize) -> &[i64] {
+        unsafe { std::slice::from_raw_parts(self.data.wrapping_add(self.n * i), self.n) }
+    }
+
+    fn at_mut(&mut self, i: usize) -> &mut [i64] {
+        unsafe { std::slice::from_raw_parts_mut(self.at_mut_ptr(i), self.n) }
+    }
+
+    fn at_ptr(&self, i: usize) -> *const i64 {
+        self.data.wrapping_add(self.n * i)
+    }
+
+    fn at_mut_ptr(&mut self, i: usize) -> *mut i64 {
+        self.data.wrapping_add(self.n * i)
+    }
+
+    fn zero(&mut self) {
+        unsafe {
+            znx_zero_i64_ref((self.n * self.limbs) as u64, self.data);
+        }
+    }
+
+    fn normalize(&mut self, log_base2k: usize, carry: &mut [u8]) {
+        assert!(
+            carry.len() >= self.n() * 8,
+            "invalid carry: carry.len()={} < self.n()*8={}",
+            carry.len(),
+            self.n() * 8
+        );
+
+        let carry_i64: &mut [i64] = cast_mut(carry);
+
+        unsafe {
+            znx::znx_zero_i64_ref(self.n() as u64, carry_i64.as_mut_ptr());
+            (0..self.limbs()).rev().for_each(|i| {
+                znx::znx_normalize(
+                    self.n as u64,
+                    log_base2k as u64,
+                    self.at_mut_ptr(i),
+                    carry_i64.as_mut_ptr(),
+                    self.at_mut_ptr(i),
+                    carry_i64.as_mut_ptr(),
+                )
+            });
+        }
+    }
+}
+
+impl VecZnxApi for VecZnx {
+    fn bytes_of(n: usize, limbs: usize) -> usize {
+        bytes_of_vec_znx(n, limbs)
+    }
+
+    /// Returns a new struct implementing [VecZnxApi] with the provided data as backing array.
+    ///
+    /// The struct will take ownership of `buf[..VecZnxApi::bytes_of(n, limbs)]`.
+    ///
+    /// The user must ensure that `buf` is properly aligned and that
+    /// its size is at least equal to [VecZnxApi::bytes_of].
+    fn from_bytes(n: usize, limbs: usize, buf: &mut [u8]) -> impl VecZnxApi {
+        let size = Self::bytes_of(n, limbs);
+        assert!(
+            buf.len() >= size,
+            "invalid buffer: buf.len()={} < self.buffer_size(n={}, limbs={})={}",
+            buf.len(),
+            n,
+            limbs,
+            size
+        );
+
+        VecZnx {
+            n,
+            data: alias_mut_slice_to_vec(cast_mut(&mut buf[..size])),
+        }
+    }
+
+    /// Returns a non-mutable pointer to the backing array of the [VecZnx].
+    fn as_ptr(&self) -> *const i64 {
+        self.data.as_ptr()
+    }
+
+    /// Returns a mutable pointer to the backing array of the [VecZnx].
+    fn as_mut_ptr(&mut self) -> *mut i64 {
+        self.data.as_mut_ptr()
+    }
+
+    /// Returns a non-mutable reference to the i-th limb of the [VecZnx].
+    fn at(&self, i: usize) -> &[i64] {
+        &self.data[i * self.n..(i + 1) * self.n]
+    }
+
+    /// Returns a mutable reference to the i-th limb of the [VecZnx].
+    fn at_mut(&mut self, i: usize) -> &mut [i64] {
+        &mut self.data[i * self.n..(i + 1) * self.n]
+    }
+
+    /// Returns a non-mutable pointer to the i-th limb of the [VecZnx].
+    fn at_ptr(&self, i: usize) -> *const i64 {
+        &self.data[i * self.n] as *const i64
+    }
+
+    /// Returns a mutable pointer to the i-th limb of the [VecZnx].
+    fn at_mut_ptr(&mut self, i: usize) -> *mut i64 {
+        &mut self.data[i * self.n] as *mut i64
+    }
+
+    /// Zeroes the backing array of the [VecZnx].
+    fn zero(&mut self) {
+        unsafe { znx::znx_zero_i64_ref(self.data.len() as u64, self.data.as_mut_ptr()) }
+    }
+
+    fn normalize(&mut self, log_base2k: usize, carry: &mut [u8]) {
+        assert!(
+            carry.len() >= self.n() * 8,
+            "invalid carry: carry.len()={} < self.n()*8={}",
+            carry.len(),
+            self.n() * 8
+        );
+
+        let carry_i64: &mut [i64] = cast_mut(carry);
+
+        unsafe {
+            znx::znx_zero_i64_ref(self.n() as u64, carry_i64.as_mut_ptr());
+            (0..self.limbs()).rev().for_each(|i| {
+                znx::znx_normalize(
+                    self.n as u64,
+                    log_base2k as u64,
+                    self.at_mut_ptr(i),
+                    carry_i64.as_mut_ptr(),
+                    self.at_mut_ptr(i),
+                    carry_i64.as_mut_ptr(),
+                )
+            });
+        }
+    }
+}
+
 /// [VecZnx] represents a vector of small norm polynomials of Zn\[X\] with [i64] coefficients.
 /// A [VecZnx] is composed of multiple Zn\[X\] polynomials stored in a single contiguous array
 /// in memory.
@@ -26,32 +225,6 @@ impl VecZnx {
         }
     }
 
-    /// Returns the minimum size of the [u8] array required to assign a
-    /// new backend array to a [VecZnx] through [VecZnx::from_bytes].
-    pub fn bytes(n: usize, limbs: usize) -> usize {
-        n * limbs * 8
-    }
-
-    /// Returns a new [VecZnx] with the provided data as backing array.
-    /// User must ensure that data is properly alligned and that
-    /// the size of data is at least equal to [Module::bytes_of_vec_znx].
-    pub fn from_bytes(n: usize, limbs: usize, buf: &mut [u8]) -> VecZnx {
-        let size = Self::bytes(n, limbs);
-        assert!(
-            buf.len() >= size,
-            "invalid buffer: buf.len()={} < self.buffer_size(n={}, limbs={})={}",
-            buf.len(),
-            n,
-            limbs,
-            size
-        );
-
-        VecZnx {
-            n: n,
-            data: alias_mut_slice_to_vec(cast_mut_u8_to_mut_i64_slice(&mut buf[..size])),
-        }
-    }
-
     /// Copies the coefficients of `a` on the receiver.
     /// Copy is done with the minimum size matching both backing arrays.
     pub fn copy_from(&mut self, a: &VecZnx) {
@@ -59,216 +232,6 @@ impl VecZnx {
         self.data[..size].copy_from_slice(&a.data[..size])
     }
 
-    /// Returns a non-mutable pointer to the backing array of the [VecZnx].
-    pub fn as_ptr(&self) -> *const i64 {
-        self.data.as_ptr()
-    }
-
-    /// Returns a mutable pointer to the backing array of the [VecZnx].
-    pub fn as_mut_ptr(&mut self) -> *mut i64 {
-        self.data.as_mut_ptr()
-    }
-
-    /// Returns a non-mutable reference to the i-th limb of the [VecZnx].
-    pub fn at(&self, i: usize) -> &[i64] {
-        &self.data[i * self.n..(i + 1) * self.n]
-    }
-
-    /// Returns a mutable reference to the i-th limb of the [VecZnx].
-    pub fn at_mut(&mut self, i: usize) -> &mut [i64] {
-        &mut self.data[i * self.n..(i + 1) * self.n]
-    }
-
-    /// Returns a non-mutable pointer to the i-th limb of the [VecZnx].
-    pub fn at_ptr(&self, i: usize) -> *const i64 {
-        &self.data[i * self.n] as *const i64
-    }
-
-    /// Returns a mutable pointer to the i-th limb of the [VecZnx].
-    pub fn at_mut_ptr(&mut self, i: usize) -> *mut i64 {
-        &mut self.data[i * self.n] as *mut i64
-    }
-
-    /// Zeroes the backing array of the [VecZnx].
-    pub fn zero(&mut self) {
-        unsafe { znx::znx_zero_i64_ref(self.data.len() as u64, self.data.as_mut_ptr()) }
-    }
-
-    /// Normalizes the [VecZnx], ensuring all coefficients are in the interval \[-2^log_base2k, 2^log_base2k].
-    ///
-    /// # Arguments
-    ///
-    /// * `log_base2k`: the base two logarithm of the base to reduce to.
-    /// * `carry`: scratch space of size at least self.n()<<3.
-    ///
-    /// # Panics
-    ///
-    /// The method will panic if carry.len() < self.data.len()*8.
-    ///
-    /// # Example
-    /// ```
-    /// use base2k::{VecZnx, Encoding, Infos, alloc_aligned};
-    /// use itertools::izip;
-    /// use sampling::source::Source;
-    ///
-    /// let n: usize = 8; // polynomial degree
-    /// let log_base2k: usize = 17; // base two logarithm of the coefficients decomposition
-    /// let limbs: usize = 5; // number of limbs (i.e. can store coeffs in the range +/- 2^{limbs * log_base2k - 1})
-    /// let log_k: usize = limbs * log_base2k - 5;
-    /// let mut a: VecZnx = VecZnx::new(n, limbs);
-    /// let mut carry: Vec<u8> = alloc_aligned::<u8>(a.n()<<3, 64);
-    /// let mut have: Vec<i64> = alloc_aligned::<i64>(a.n(), 64);
-    /// let mut source = Source::new([1; 32]);
-    ///
-    /// // Populates the first limb of the of polynomials with random i64 values.
-    /// have.iter_mut().for_each(|x| {
-    ///     *x = source
-    ///         .next_u64n(u64::MAX, u64::MAX)
-    ///         .wrapping_sub(u64::MAX / 2 + 1) as i64;
-    /// });
-    /// a.encode_vec_i64(log_base2k, log_k, &have, 63);
-    /// a.normalize(log_base2k, &mut carry);
-    ///
-    /// // Ensures normalized values are in the range +/- 2^{log_base2k-1}
-    /// let base_half = 1 << (log_base2k - 1);
-    /// a.data
-    ///     .iter()
-    ///     .for_each(|x| assert!(x.abs() <= base_half, "|x|={} > 2^(k-1)={}", x, base_half));
-    ///
-    /// // Ensures reconstructed normalized values are equal to non-normalized values.
-    /// let mut want = alloc_aligned::<i64>(a.n(), 64);
-    /// a.decode_vec_i64(log_base2k, log_k, &mut want);
-    /// izip!(want, have).for_each(|(a, b)| assert_eq!(a, b, "{} != {}", a, b));
-    /// ```
-    pub fn normalize(&mut self, log_base2k: usize, carry: &mut [u8]) {
-        assert!(
-            carry.len() >= self.n * 8,
-            "invalid carry: carry.len()={} < self.n()={}",
-            carry.len(),
-            self.n()
-        );
-
-        let carry_i64: &mut [i64] = cast_mut_u8_to_mut_i64_slice(carry);
-
-        unsafe {
-            znx::znx_zero_i64_ref(self.n() as u64, carry_i64.as_mut_ptr());
-            (0..self.limbs()).rev().for_each(|i| {
-                znx::znx_normalize(
-                    self.n as u64,
-                    log_base2k as u64,
-                    self.at_mut_ptr(i),
-                    carry_i64.as_mut_ptr(),
-                    self.at_mut_ptr(i),
-                    carry_i64.as_mut_ptr(),
-                )
-            });
-        }
-    }
-
-    /// Maps X^i to X^{ik} mod X^{n}+1. The mapping is applied independently on each limb.
-    ///
-    /// # Arguments
-    ///
-    /// * `k`: the power to which to map each coefficients.
-    /// * `limbs`: the number of limbs on which to apply the mapping.
-    ///
-    /// # Panics
-    ///
-    /// The method will panic if the argument `limbs` is greater than `self.limbs()`.
- /// - /// # Example - /// ``` - /// use base2k::{VecZnx, Encoding, Infos}; - /// use itertools::izip; - /// - /// let n: usize = 8; // polynomial degree - /// let mut a: VecZnx = VecZnx::new(n, 2); - /// let mut b: VecZnx = VecZnx::new(n, 2); - /// - /// (0..a.limbs()).for_each(|i|{ - /// a.at_mut(i).iter_mut().enumerate().for_each(|(i, x)|{ - /// *x = i as i64 - /// }) - /// }); - /// - /// b.copy_from(&a); - /// - /// a.automorphism_inplace(-1, 1); // X^i -> X^(-i) - /// let limb = b.at_mut(0); - /// (1..limb.len()).for_each(|i|{ - /// limb[n-i] = -(i as i64) - /// }); - /// izip!(a.data.iter(), b.data.iter()).for_each(|(a, b)| assert_eq!(a, b, "{} != {}", a, b)); - /// ``` - pub fn automorphism_inplace(&mut self, k: i64, limbs: usize) { - assert!( - limbs <= self.limbs(), - "invalid limbs argument: limbs={} > self.limbs()={}", - limbs, - self.limbs() - ); - unsafe { - (0..limbs).for_each(|i| { - znx::znx_automorphism_inplace_i64(self.n as u64, k, self.at_mut_ptr(i)) - }) - } - } - - /// Maps X^i to X^{ik} mod X^{n}+1. The mapping is applied independently on each limb. - /// - /// # Arguments - /// - /// * `a`: the receiver. - /// * `k`: the power to which to map each coefficients. - /// * `limbs`: the number of limbs on which to apply the mapping. - /// - /// # Panics - /// - /// The method will panic if the argument `limbs` is greater than `self.limbs()` or `a.limbs()`. - /// - /// # Example - /// ``` - /// use base2k::{VecZnx, Encoding, Infos}; - /// use itertools::izip; - /// - /// let n: usize = 8; // polynomial degree - /// let mut a: VecZnx = VecZnx::new(n, 2); - /// let mut b: VecZnx = VecZnx::new(n, 2); - /// let mut c: VecZnx = VecZnx::new(n, 2); - /// - /// (0..a.limbs()).for_each(|i|{ - /// a.at_mut(i).iter_mut().enumerate().for_each(|(i, x)|{ - /// *x = i as i64 - /// }) - /// }); - /// - /// a.automorphism(&mut b, -1, 1); // X^i -> X^(-i) - /// let limb = c.at_mut(0); - /// (1..limb.len()).for_each(|i|{ - /// limb[n-i] = -(i as i64) - /// }); - /// izip!(b.data.iter(), c.data.iter()).for_each(|(a, b)| assert_eq!(a, b, "{} != {}", a, b)); - /// ``` - pub fn automorphism(&mut self, a: &mut VecZnx, k: i64, limbs: usize) { - assert!( - limbs <= self.limbs(), - "invalid limbs argument: limbs={} > self.limbs()={}", - limbs, - self.limbs() - ); - assert!( - limbs <= a.limbs(), - "invalid limbs argument: limbs={} > a.limbs()={}", - limbs, - a.limbs() - ); - unsafe { - (0..limbs).for_each(|i| { - znx::znx_automorphism_i64(self.n as u64, k, a.at_mut_ptr(i), self.at_ptr(i)) - }) - } - } - /// Truncates the precision of the [VecZnx] by k bits. /// /// # Arguments @@ -305,11 +268,13 @@ impl VecZnx { /// /// The method will panic if carry.len() < self.n() * self.limbs() << 3. pub fn rsh(&mut self, log_base2k: usize, k: usize, carry: &mut [u8]) { + let n: usize = self.n(); + assert!( - carry.len() >> 3 >= self.n(), + carry.len() >> 3 >= n, "invalid carry: carry.len()/8={} < self.n()={}", carry.len() >> 3, - self.n() + n ); let limbs: usize = self.limbs(); @@ -323,10 +288,10 @@ impl VecZnx { let k_rem = k % log_base2k; if k_rem != 0 { - let carry_i64: &mut [i64] = cast_mut_u8_to_mut_i64_slice(carry); + let carry_i64: &mut [i64] = cast_mut(carry); unsafe { - znx::znx_zero_i64_ref(self.n() as u64, carry_i64.as_mut_ptr()); + znx::znx_zero_i64_ref(n as u64, carry_i64.as_mut_ptr()); } let mask: i64 = (1 << k_rem) - 1; @@ -388,34 +353,34 @@ pub trait VecZnxOps { fn bytes_of_vec_znx(&self, limbs: usize) -> usize; /// c <- a + b. 
-    fn vec_znx_add(&self, c: &mut VecZnx, a: &VecZnx, b: &VecZnx);
+    fn vec_znx_add<T: VecZnxApi + Infos>(&self, c: &mut T, a: &T, b: &T);
 
     /// b <- b + a.
-    fn vec_znx_add_inplace(&self, b: &mut VecZnx, a: &VecZnx);
+    fn vec_znx_add_inplace<T: VecZnxApi + Infos>(&self, b: &mut T, a: &T);
 
     /// c <- a - b.
-    fn vec_znx_sub(&self, c: &mut VecZnx, a: &VecZnx, b: &VecZnx);
+    fn vec_znx_sub<T: VecZnxApi + Infos>(&self, c: &mut T, a: &T, b: &T);
 
     /// b <- b - a.
-    fn vec_znx_sub_inplace(&self, b: &mut VecZnx, a: &VecZnx);
+    fn vec_znx_sub_inplace<T: VecZnxApi + Infos>(&self, b: &mut T, a: &T);
 
     /// b <- -a.
-    fn vec_znx_negate(&self, b: &mut VecZnx, a: &VecZnx);
+    fn vec_znx_negate<T: VecZnxApi + Infos>(&self, b: &mut T, a: &T);
 
     /// b <- -b.
-    fn vec_znx_negate_inplace(&self, a: &mut VecZnx);
+    fn vec_znx_negate_inplace<T: VecZnxApi + Infos>(&self, a: &mut T);
 
     /// b <- a * X^k (mod X^{n} + 1)
-    fn vec_znx_rotate(&self, k: i64, b: &mut VecZnx, a: &VecZnx);
+    fn vec_znx_rotate<T: VecZnxApi + Infos>(&self, k: i64, b: &mut T, a: &T);
 
     /// a <- a * X^k (mod X^{n} + 1)
-    fn vec_znx_rotate_inplace(&self, k: i64, a: &mut VecZnx);
+    fn vec_znx_rotate_inplace<T: VecZnxApi + Infos>(&self, k: i64, a: &mut T);
 
     /// b <- phi_k(a) where phi_k: X^i -> X^{i*k} (mod (X^{n} + 1))
-    fn vec_znx_automorphism(&self, k: i64, b: &mut VecZnx, a: &VecZnx);
+    fn vec_znx_automorphism<T: VecZnxApi + Infos>(&self, k: i64, b: &mut T, a: &T, a_limbs: usize);
 
     /// a <- phi_k(a) where phi_k: X^i -> X^{i*k} (mod (X^{n} + 1))
-    fn vec_znx_automorphism_inplace(&self, k: i64, a: &mut VecZnx);
+    fn vec_znx_automorphism_inplace<T: VecZnxApi + Infos>(&self, k: i64, a: &mut T, a_limbs: usize);
 
     /// Splits b into subrings and copies them into a.
     ///
@@ -444,7 +409,7 @@ impl VecZnxOps for Module {
     }
 
     // c <- a + b
-    fn vec_znx_add(&self, c: &mut VecZnx, a: &VecZnx, b: &VecZnx) {
+    fn vec_znx_add<T: VecZnxApi + Infos>(&self, c: &mut T, a: &T, b: &T) {
         unsafe {
             vec_znx::vec_znx_add(
                 self.0,
@@ -462,7 +427,7 @@ impl VecZnxOps for Module {
     }
 
     // b <- b + a
-    fn vec_znx_add_inplace(&self, b: &mut VecZnx, a: &VecZnx) {
+    fn vec_znx_add_inplace<T: VecZnxApi + Infos>(&self, b: &mut T, a: &T) {
         unsafe {
             vec_znx::vec_znx_add(
                 self.0,
@@ -480,7 +445,7 @@ impl VecZnxOps for Module {
     }
 
     // c <- a - b
-    fn vec_znx_sub(&self, c: &mut VecZnx, a: &VecZnx, b: &VecZnx) {
+    fn vec_znx_sub<T: VecZnxApi + Infos>(&self, c: &mut T, a: &T, b: &T) {
         unsafe {
             vec_znx::vec_znx_sub(
                 self.0,
@@ -498,7 +463,7 @@ impl VecZnxOps for Module {
     }
 
     // b <- b - a
-    fn vec_znx_sub_inplace(&self, b: &mut VecZnx, a: &VecZnx) {
+    fn vec_znx_sub_inplace<T: VecZnxApi + Infos>(&self, b: &mut T, a: &T) {
         unsafe {
             vec_znx::vec_znx_sub(
                 self.0,
@@ -515,7 +480,7 @@ impl VecZnxOps for Module {
         }
     }
 
-    fn vec_znx_negate(&self, b: &mut VecZnx, a: &VecZnx) {
+    fn vec_znx_negate<T: VecZnxApi + Infos>(&self, b: &mut T, a: &T) {
         unsafe {
             vec_znx::vec_znx_negate(
                 self.0,
@@ -529,7 +494,7 @@ impl VecZnxOps for Module {
         }
     }
 
-    fn vec_znx_negate_inplace(&self, a: &mut VecZnx) {
+    fn vec_znx_negate_inplace<T: VecZnxApi + Infos>(&self, a: &mut T) {
         unsafe {
             vec_znx::vec_znx_negate(
                 self.0,
@@ -543,7 +508,7 @@ impl VecZnxOps for Module {
         }
     }
 
-    fn vec_znx_rotate(&self, k: i64, a: &mut VecZnx, b: &VecZnx) {
+    fn vec_znx_rotate<T: VecZnxApi + Infos>(&self, k: i64, a: &mut T, b: &T) {
         unsafe {
             vec_znx::vec_znx_rotate(
                 self.0,
@@ -558,7 +523,7 @@ impl VecZnxOps for Module {
         }
     }
 
-    fn vec_znx_rotate_inplace(&self, k: i64, a: &mut VecZnx) {
+    fn vec_znx_rotate_inplace<T: VecZnxApi + Infos>(&self, k: i64, a: &mut T) {
        unsafe {
             vec_znx::vec_znx_rotate(
                 self.0,
@@ -573,7 +538,47 @@ impl VecZnxOps for Module {
         }
     }
 
-    fn vec_znx_automorphism(&self, k: i64, b: &mut VecZnx, a: &VecZnx) {
+    /// Maps X^i to X^{ik} mod X^{n}+1. The mapping is applied independently on each limb.
+    ///
+    /// # Arguments
+    ///
+    /// * `a`: input.
+    /// * `b`: output.
+    /// * `k`: the power to which to map each coefficient.
+    /// * `limbs_a`: the number of limbs on which to apply the mapping.
+    ///
+    /// # Panics
+    ///
+    /// The method will panic if the argument `limbs_a` is greater than `a.limbs()`.
+    ///
+    /// # Example
+    /// ```
+    /// use base2k::{Module, FFT64, VecZnx, Encoding, Infos, VecZnxApi, VecZnxOps};
+    /// use itertools::izip;
+    ///
+    /// let n: usize = 8; // polynomial degree
+    /// let module = Module::new::<FFT64>(n);
+    /// let mut a: VecZnx = module.new_vec_znx(2);
+    /// let mut b: VecZnx = module.new_vec_znx(2);
+    /// let mut c: VecZnx = module.new_vec_znx(2);
+    ///
+    /// (0..a.limbs()).for_each(|i|{
+    ///     a.at_mut(i).iter_mut().enumerate().for_each(|(i, x)|{
+    ///         *x = i as i64
+    ///     })
+    /// });
+    ///
+    /// module.vec_znx_automorphism(-1, &mut b, &a, 1); // X^i -> X^(-i)
+    /// let limb = c.at_mut(0);
+    /// (1..limb.len()).for_each(|i|{
+    ///     limb[n-i] = -(i as i64)
+    /// });
+    /// izip!(b.data.iter(), c.data.iter()).for_each(|(a, b)| assert_eq!(a, b, "{} != {}", a, b));
+    /// ```
+    fn vec_znx_automorphism<T: VecZnxApi + Infos>(&self, k: i64, b: &mut T, a: &T, limbs_a: usize) {
+        assert_eq!(a.n(), self.n());
+        assert_eq!(b.n(), self.n());
+        assert!(a.limbs() >= limbs_a);
         unsafe {
             vec_znx::vec_znx_automorphism(
                 self.0,
@@ -582,13 +587,55 @@ impl VecZnxOps for Module {
                 b.limbs() as u64,
                 b.n() as u64,
                 a.as_ptr(),
-                a.limbs() as u64,
+                limbs_a as u64,
                 a.n() as u64,
             );
         }
     }
 
-    fn vec_znx_automorphism_inplace(&self, k: i64, a: &mut VecZnx) {
+    /// Maps X^i to X^{ik} mod X^{n}+1. The mapping is applied independently on each limb.
+    ///
+    /// # Arguments
+    ///
+    /// * `a`: input and output.
+    /// * `k`: the power to which to map each coefficient.
+    /// * `limbs_a`: the number of limbs on which to apply the mapping.
+    ///
+    /// # Panics
+    ///
+    /// The method will panic if the argument `limbs_a` is greater than `a.limbs()`.
+    ///
+    /// # Example
+    /// ```
+    /// use base2k::{Module, FFT64, VecZnx, Encoding, Infos, VecZnxApi, VecZnxOps};
+    /// use itertools::izip;
+    ///
+    /// let n: usize = 8; // polynomial degree
+    /// let module = Module::new::<FFT64>(n);
+    /// let mut a: VecZnx = VecZnx::new(n, 2);
+    /// let mut b: VecZnx = VecZnx::new(n, 2);
+    ///
+    /// (0..a.limbs()).for_each(|i|{
+    ///     a.at_mut(i).iter_mut().enumerate().for_each(|(i, x)|{
+    ///         *x = i as i64
+    ///     })
+    /// });
+    ///
+    /// b.copy_from(&a);
+    ///
+    /// module.vec_znx_automorphism_inplace(-1, &mut a, 1); // X^i -> X^(-i)
+    /// let limb = b.at_mut(0);
+    /// (1..limb.len()).for_each(|i|{
+    ///     limb[n-i] = -(i as i64)
+    /// });
+    /// izip!(a.data.iter(), b.data.iter()).for_each(|(a, b)| assert_eq!(a, b, "{} != {}", a, b));
+    /// ```
+    fn vec_znx_automorphism_inplace<T: VecZnxApi + Infos>(
+        &self,
+        k: i64,
+        a: &mut T,
+        limbs_a: usize,
+    ) {
+        assert_eq!(a.n(), self.n());
+        assert!(a.limbs() >= limbs_a);
         unsafe {
             vec_znx::vec_znx_automorphism(
                 self.0,
@@ -597,7 +644,7 @@ impl VecZnxOps for Module {
                 a.limbs() as u64,
                 a.n() as u64,
                 a.as_ptr(),
-                a.limbs() as u64,
+                limbs_a as u64,
                 a.n() as u64,
             );
         }
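A sketch of the `VecZnxBorrow` pattern the new trait enables (illustrative; `n`, `limbs` and the 64-byte alignment are assumptions): a borrow wraps caller-owned scratch bytes so the generic `VecZnxOps` methods can run without allocating.

```rust
let (n, limbs) = (8usize, 3usize);
let mut scratch: Vec<u8> = alloc_aligned_u8(bytes_of_vec_znx(n, limbs), 64);
// The borrow reads and writes directly into `scratch`; no copy is made.
let mut tmp = VecZnxBorrow::from_bytes(n, limbs, &mut scratch);
tmp.zero();
```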
diff --git a/base2k/src/vec_znx_big.rs b/base2k/src/vec_znx_big.rs
index 83ed428..fa36e6f 100644
--- a/base2k/src/vec_znx_big.rs
+++ b/base2k/src/vec_znx_big.rs
@@ -1,6 +1,6 @@
 use crate::ffi::vec_znx_big;
 use crate::ffi::vec_znx_dft;
-use crate::{Infos, Module, VecZnx, VecZnxDft};
+use crate::{Infos, Module, VecZnxApi, VecZnxDft};
 
 pub struct VecZnxBig(pub *mut vec_znx_big::vec_znx_bigcoeff_t, pub usize);
 
@@ -23,11 +23,9 @@ impl VecZnxBig {
     }
 }
 
-impl Module {
-    // Allocates a vector Z[X]/(X^N+1) that stores not normalized values.
-    pub fn new_vec_znx_big(&self, limbs: usize) -> VecZnxBig {
-        unsafe { VecZnxBig(vec_znx_big::new_vec_znx_big(self.0, limbs as u64), limbs) }
-    }
+pub trait VecZnxBigOps {
+    /// Allocates a vector Z[X]/(X^N+1) that stores non-normalized values.
+    fn new_vec_znx_big(&self, limbs: usize) -> VecZnxBig;
 
     /// Returns a new [VecZnxBig] with the provided bytes array as backing array.
     ///
@@ -38,24 +36,78 @@ impl Module {
     ///
     /// # Panics
     /// If `bytes.len()` < [Module::bytes_of_vec_znx_big].
-    pub fn new_vec_znx_big_from_bytes(&self, limbs: usize, bytes: &mut [u8]) -> VecZnxBig {
+    fn new_vec_znx_big_from_bytes(&self, limbs: usize, bytes: &mut [u8]) -> VecZnxBig;
+
+    /// Returns the minimum number of bytes necessary to allocate
+    /// a new [VecZnxBig] through [VecZnxBig::from_bytes].
+    fn bytes_of_vec_znx_big(&self, limbs: usize) -> usize;
+
+    /// b <- b - a
+    fn vec_znx_big_sub_small_a_inplace<T: VecZnxApi + Infos>(&self, b: &mut VecZnxBig, a: &T);
+
+    /// c <- b - a
+    fn vec_znx_big_sub_small_a<T: VecZnxApi + Infos>(
+        &self,
+        c: &mut VecZnxBig,
+        a: &T,
+        b: &VecZnxBig,
+    );
+
+    /// c <- b + a
+    fn vec_znx_big_add_small<T: VecZnxApi + Infos>(&self, c: &mut VecZnxBig, a: &T, b: &VecZnxBig);
+
+    /// b <- b + a
+    fn vec_znx_big_add_small_inplace<T: VecZnxApi + Infos>(&self, b: &mut VecZnxBig, a: &T);
+
+    fn vec_znx_big_normalize_tmp_bytes(&self) -> usize;
+
+    /// b <- normalize(a)
+    fn vec_znx_big_normalize<T: VecZnxApi + Infos>(
+        &self,
+        log_base2k: usize,
+        b: &mut T,
+        a: &VecZnxBig,
+        tmp_bytes: &mut [u8],
+    );
+
+    fn vec_znx_big_range_normalize_base2k_tmp_bytes(&self) -> usize;
+
+    fn vec_znx_big_range_normalize_base2k<T: VecZnxApi + Infos>(
+        &self,
+        log_base2k: usize,
+        res: &mut T,
+        a: &VecZnxBig,
+        a_range_begin: usize,
+        a_range_xend: usize,
+        a_range_step: usize,
+        tmp_bytes: &mut [u8],
+    );
+
+    fn vec_znx_big_automorphism(&self, gal_el: i64, b: &mut VecZnxBig, a: &VecZnxBig);
+
+    fn vec_znx_big_automorphism_inplace(&self, gal_el: i64, a: &mut VecZnxBig);
+}
+
+impl VecZnxBigOps for Module {
+    fn new_vec_znx_big(&self, limbs: usize) -> VecZnxBig {
+        unsafe { VecZnxBig(vec_znx_big::new_vec_znx_big(self.0, limbs as u64), limbs) }
+    }
+
+    fn new_vec_znx_big_from_bytes(&self, limbs: usize, bytes: &mut [u8]) -> VecZnxBig {
         assert!(
-            bytes.len() >= self.bytes_of_vec_znx_big(limbs),
+            bytes.len() >= <Module as VecZnxBigOps>::bytes_of_vec_znx_big(self, limbs),
-            "invalid bytes: bytes.len()={} < bytes_of_vec_znx_dft={}",
+            "invalid bytes: bytes.len()={} < bytes_of_vec_znx_big={}",
             bytes.len(),
-            self.bytes_of_vec_znx_big(limbs)
+            <Module as VecZnxBigOps>::bytes_of_vec_znx_big(self, limbs)
         );
         VecZnxBig::from_bytes(limbs, bytes)
     }
 
-    /// Returns the minimum number of bytes necessary to allocate
-    /// a new [VecZnxBig] through [VecZnxBig::from_bytes].
-    pub fn bytes_of_vec_znx_big(&self, limbs: usize) -> usize {
+    fn bytes_of_vec_znx_big(&self, limbs: usize) -> usize {
         unsafe { vec_znx_big::bytes_of_vec_znx_big(self.0, limbs as u64) as usize }
     }
 
-    // b <- b - a
-    pub fn vec_znx_big_sub_small_a_inplace(&self, b: &mut VecZnxBig, a: &VecZnx) {
+    fn vec_znx_big_sub_small_a_inplace<T: VecZnxApi + Infos>(&self, b: &mut VecZnxBig, a: &T) {
         unsafe {
             vec_znx_big::vec_znx_big_sub_small_a(
                 self.0,
@@ -70,8 +122,12 @@ impl VecZnxBigOps for Module {
         }
     }
 
-    // c <- b - a
-    pub fn vec_znx_big_sub_small_a(&self, c: &mut VecZnxBig, a: &VecZnx, b: &VecZnxBig) {
+    fn vec_znx_big_sub_small_a<T: VecZnxApi + Infos>(
+        &self,
+        c: &mut VecZnxBig,
+        a: &T,
+        b: &VecZnxBig,
+    ) {
         unsafe {
             vec_znx_big::vec_znx_big_sub_small_a(
                 self.0,
@@ -86,8 +142,7 @@ impl VecZnxBigOps for Module {
         }
     }
 
-    // c <- b + a
-    pub fn vec_znx_big_add_small(&self, c: &mut VecZnxBig, a: &VecZnx, b: &VecZnxBig) {
+    fn vec_znx_big_add_small<T: VecZnxApi + Infos>(&self, c: &mut VecZnxBig, a: &T, b: &VecZnxBig) {
         unsafe {
             vec_znx_big::vec_znx_big_add_small(
                 self.0,
@@ -102,8 +157,7 @@ impl VecZnxBigOps for Module {
         }
     }
 
-    // b <- b + a
-    pub fn vec_znx_big_add_small_inplace(&self, b: &mut VecZnxBig, a: &VecZnx) {
+    fn vec_znx_big_add_small_inplace<T: VecZnxApi + Infos>(&self, b: &mut VecZnxBig, a: &T) {
         unsafe {
             vec_znx_big::vec_znx_big_add_small(
                 self.0,
@@ -118,23 +172,22 @@ impl VecZnxBigOps for Module {
         }
     }
 
-    pub fn vec_znx_big_normalize_tmp_bytes(&self) -> usize {
+    fn vec_znx_big_normalize_tmp_bytes(&self) -> usize {
         unsafe { vec_znx_big::vec_znx_big_normalize_base2k_tmp_bytes(self.0) as usize }
     }
 
-    // b <- normalize(a)
-    pub fn vec_znx_big_normalize(
+    fn vec_znx_big_normalize<T: VecZnxApi + Infos>(
         &self,
         log_base2k: usize,
-        b: &mut VecZnx,
+        b: &mut T,
         a: &VecZnxBig,
         tmp_bytes: &mut [u8],
     ) {
         assert!(
-            tmp_bytes.len() >= self.vec_znx_big_normalize_tmp_bytes(),
+            tmp_bytes.len() >= <Module as VecZnxBigOps>::vec_znx_big_normalize_tmp_bytes(self),
             "invalid tmp_bytes: tmp_bytes.len()={} < self.vec_znx_big_normalize_tmp_bytes()={}",
             tmp_bytes.len(),
-            self.vec_znx_big_normalize_tmp_bytes()
+            <Module as VecZnxBigOps>::vec_znx_big_normalize_tmp_bytes(self)
         );
         unsafe {
             vec_znx_big::vec_znx_big_normalize_base2k(
@@ -150,14 +203,14 @@ impl VecZnxBigOps for Module {
         }
     }
 
-    pub fn vec_znx_big_range_normalize_base2k_tmp_bytes(&self) -> usize {
+    fn vec_znx_big_range_normalize_base2k_tmp_bytes(&self) -> usize {
         unsafe { vec_znx_big::vec_znx_big_range_normalize_base2k_tmp_bytes(self.0) as usize }
     }
 
-    pub fn vec_znx_big_range_normalize_base2k(
+    fn vec_znx_big_range_normalize_base2k<T: VecZnxApi + Infos>(
         &self,
         log_base2k: usize,
-        res: &mut VecZnx,
+        res: &mut T,
         a: &VecZnxBig,
         a_range_begin: usize,
         a_range_xend: usize,
@@ -165,10 +218,10 @@ impl VecZnxBigOps for Module {
         tmp_bytes: &mut [u8],
     ) {
         assert!(
-            tmp_bytes.len() >= self.vec_znx_big_range_normalize_base2k_tmp_bytes(),
+            tmp_bytes.len() >= <Module as VecZnxBigOps>::vec_znx_big_range_normalize_base2k_tmp_bytes(self),
            "invalid tmp_bytes: tmp_bytes.len()={} < self.vec_znx_big_range_normalize_base2k_tmp_bytes()={}",
             tmp_bytes.len(),
-            self.vec_znx_big_range_normalize_base2k_tmp_bytes()
+            <Module as VecZnxBigOps>::vec_znx_big_range_normalize_base2k_tmp_bytes(self)
         );
         unsafe {
             vec_znx_big::vec_znx_big_range_normalize_base2k(
@@ -186,7 +239,7 @@ impl VecZnxBigOps for Module {
         }
     }
 
-    pub fn vec_znx_big_automorphism(&self, gal_el: i64, b: &mut VecZnxBig, a: &VecZnxBig) {
+    fn vec_znx_big_automorphism(&self, gal_el: i64, b: &mut VecZnxBig, a: &VecZnxBig) {
         unsafe {
             vec_znx_big::vec_znx_big_automorphism(
                 self.0,
@@ -199,7 +252,7 @@ impl VecZnxBigOps for Module {
         }
     }
 
-    pub fn vec_znx_big_automorphism_inplace(&self, gal_el: i64, a: &mut VecZnxBig) {
+    fn vec_znx_big_automorphism_inplace(&self, gal_el: i64, a: &mut VecZnxBig) {
         unsafe {
             vec_znx_big::vec_znx_big_automorphism(
                 self.0,
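The intended round trip through the big-coefficient domain, sketched with the methods above (the buffer sizes, `limbs`, `log_base2k` and `a` are placeholders): accumulate in the non-normalized domain, then renormalize back into a small-coefficient `VecZnx`.

```rust
let mut big: VecZnxBig = module.new_vec_znx_big(limbs);
module.vec_znx_big_add_small_inplace(&mut big, &a); // a: &VecZnx (or any VecZnxApi)
let mut tmp: Vec<u8> = alloc_aligned_u8(module.vec_znx_big_normalize_tmp_bytes(), 64);
let mut res: VecZnx = module.new_vec_znx(limbs);
module.vec_znx_big_normalize(log_base2k, &mut res, &big, &mut tmp);
```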
diff --git a/base2k/src/vec_znx_dft.rs b/base2k/src/vec_znx_dft.rs
index 38db9e5..6ae2dca 100644
--- a/base2k/src/vec_znx_dft.rs
+++ b/base2k/src/vec_znx_dft.rs
@@ -1,7 +1,7 @@
 use crate::ffi::vec_znx_big;
 use crate::ffi::vec_znx_dft;
 use crate::ffi::vec_znx_dft::bytes_of_vec_znx_dft;
-use crate::{Module, VecZnx, VecZnxBig};
+use crate::{Infos, Module, VecZnx, VecZnxApi, VecZnxBig};
 
 pub struct VecZnxDft(pub *mut vec_znx_dft::vec_znx_dft_t, pub usize);
 
@@ -24,11 +24,9 @@ impl VecZnxDft {
     }
 }
 
-impl Module {
-    // Allocates a vector Z[X]/(X^N+1) that stores normalized in the DFT space.
-    pub fn new_vec_znx_dft(&self, limbs: usize) -> VecZnxDft {
-        unsafe { VecZnxDft(vec_znx_dft::new_vec_znx_dft(self.0, limbs as u64), limbs) }
-    }
+pub trait VecZnxDftOps {
+    /// Allocates a vector Z[X]/(X^N+1) that stores normalized values in the DFT space.
+    fn new_vec_znx_dft(&self, limbs: usize) -> VecZnxDft;
 
     /// Returns a new [VecZnxDft] with the provided bytes array as backing array.
     ///
@@ -39,24 +37,57 @@ impl Module {
     ///
     /// # Panics
     /// If `bytes.len()` < [Module::bytes_of_vec_znx_dft].
-    pub fn new_vec_znx_dft_from_bytes(&self, limbs: usize, bytes: &mut [u8]) -> VecZnxDft {
+    fn new_vec_znx_dft_from_bytes(&self, limbs: usize, bytes: &mut [u8]) -> VecZnxDft;
+
+    /// Returns the minimum number of bytes necessary to allocate
+    /// a new [VecZnxDft] through [VecZnxDft::from_bytes].
+    ///
+    /// # Arguments
+    ///
+    /// * `limbs`: the number of limbs of the [VecZnxDft].
+    fn bytes_of_vec_znx_dft(&self, limbs: usize) -> usize;
+
+    /// Returns the size of the scratch space required by [VecZnxDftOps::vec_znx_idft].
+    fn vec_znx_idft_tmp_bytes(&self) -> usize;
+
+    /// b <- IDFT(a), uses a as scratch space.
+    fn vec_znx_idft_tmp_a(&self, b: &mut VecZnxBig, a: &mut VecZnxDft, a_limbs: usize);
+
+    fn vec_znx_idft(
+        &self,
+        b: &mut VecZnxBig,
+        a: &mut VecZnxDft,
+        a_limbs: usize,
+        tmp_bytes: &mut [u8],
+    );
+
+    fn vec_znx_dft<T: VecZnxApi + Infos>(&self, b: &mut VecZnxDft, a: &T, a_limbs: usize);
+}
+
+impl VecZnxDftOps for Module {
+    fn new_vec_znx_dft(&self, limbs: usize) -> VecZnxDft {
+        unsafe { VecZnxDft(vec_znx_dft::new_vec_znx_dft(self.0, limbs as u64), limbs) }
+    }
+
+    fn new_vec_znx_dft_from_bytes(&self, limbs: usize, bytes: &mut [u8]) -> VecZnxDft {
         assert!(
-            bytes.len() >= self.bytes_of_vec_znx_dft(limbs),
+            bytes.len() >= <Module as VecZnxDftOps>::bytes_of_vec_znx_dft(self, limbs),
             "invalid bytes: bytes.len()={} < bytes_of_vec_znx_dft={}",
             bytes.len(),
-            self.bytes_of_vec_znx_dft(limbs)
+            <Module as VecZnxDftOps>::bytes_of_vec_znx_dft(self, limbs)
         );
         VecZnxDft::from_bytes(limbs, bytes)
     }
 
-    /// Returns the minimum number of bytes necessary to allocate
-    /// a new [VecZnxDft] through [VecZnxDft::from_bytes].
-    pub fn bytes_of_vec_znx_dft(&self, limbs: usize) -> usize {
+    fn bytes_of_vec_znx_dft(&self, limbs: usize) -> usize {
         unsafe { bytes_of_vec_znx_dft(self.0, limbs as u64) as usize }
     }
 
-    // b <- IDFT(a), uses a as scratch space.
-    pub fn vec_znx_idft_tmp_a(&self, b: &mut VecZnxBig, a: &mut VecZnxDft, a_limbs: usize) {
+    fn vec_znx_idft_tmp_a(&self, b: &mut VecZnxBig, a: &mut VecZnxDft, a_limbs: usize) {
         assert!(
             b.limbs() >= a_limbs,
             "invalid c_vector: b_vector.limbs()={} < a_limbs={}",
@@ -68,8 +99,7 @@ impl VecZnxDftOps for Module {
         }
     }
 
-    // Returns the size of the scratch space for [vec_znx_idft].
-    pub fn vec_znx_idft_tmp_bytes(&self) -> usize {
+    fn vec_znx_idft_tmp_bytes(&self) -> usize {
         unsafe { vec_znx_dft::vec_znx_idft_tmp_bytes(self.0) as usize }
     }
 
@@ -77,7 +107,7 @@ impl VecZnxDftOps for Module {
     ///
     /// # Panics
     /// If b.limbs < a_limbs
-    pub fn vec_znx_dft(&self, b: &mut VecZnxDft, a: &VecZnx, a_limbs: usize) {
+    fn vec_znx_dft<T: VecZnxApi + Infos>(&self, b: &mut VecZnxDft, a: &T, a_limbs: usize) {
         assert!(
             b.limbs() >= a_limbs,
             "invalid a_limbs: b.limbs()={} < a_limbs={}",
             b.limbs(),
             a_limbs
         );
         unsafe {
             vec_znx_dft::vec_znx_dft(
                 self.0,
                 b.0,
                 b.limbs() as u64,
                 a.as_ptr(),
                 a_limbs as u64,
-                a.n as u64,
+                a.n() as u64,
             )
         }
     }
 
     // b <- IDFT(a), scratch space size obtained with [vec_znx_idft_tmp_bytes].
-    pub fn vec_znx_idft(
+    fn vec_znx_idft(
         &self,
         b: &mut VecZnxBig,
         a: &mut VecZnxDft,
@@ -117,10 +147,10 @@ impl VecZnxDftOps for Module {
             a_limbs
         );
         assert!(
-            tmp_bytes.len() <= self.vec_znx_idft_tmp_bytes(),
+            tmp_bytes.len() >= <Module as VecZnxDftOps>::vec_znx_idft_tmp_bytes(self),
             "invalid tmp_bytes: tmp_bytes.len()={} < self.vec_znx_idft_tmp_bytes()={}",
             tmp_bytes.len(),
-            self.vec_znx_idft_tmp_bytes()
+            <Module as VecZnxDftOps>::vec_znx_idft_tmp_bytes(self)
         );
         unsafe {
             vec_znx_dft::vec_znx_idft(
diff --git a/base2k/src/vmp.rs b/base2k/src/vmp.rs
index 08b6e85..b240817 100644
--- a/base2k/src/vmp.rs
+++ b/base2k/src/vmp.rs
@@ -1,6 +1,5 @@
 use crate::ffi::vmp;
-use crate::{Infos, Module, VecZnx, VecZnxDft};
-use std::cmp::min;
+use crate::{Infos, Module, VecZnx, VecZnxApi, VecZnxDft};
 
 /// Vector Matrix Product Prepared Matrix: a vector of [VecZnx],
 /// stored as a 3D matrix in the DFT domain in a single contiguous array.
@@ -110,7 +109,7 @@ pub trait VmpPMatOps {
     ///
     /// # Example
     /// ```
-    /// use base2k::{Module, Matrix3D, VmpPMat, VmpPMatOps, FFT64, Free};
+    /// use base2k::{Module, VmpPMat, VmpPMatOps, FFT64, Free};
     /// use std::cmp::min;
     ///
     /// let n: usize = 1024;
@@ -118,17 +117,12 @@ pub trait VmpPMatOps {
     /// let rows = 5;
     /// let cols = 6;
     ///
-    /// let mut b_mat: Matrix3D<i64> = Matrix3D::new(rows, cols, n);
-    ///
-    /// // Populates the i-th row of b_math with X^1 * 2^(i * log_w) (here log_w is undefined)
-    /// (0..min(rows, cols)).for_each(|i| {
-    ///     b_mat.at_mut(i, i)[1] = 1 as i64;
-    /// });
+    /// let mut b_mat: Vec<i64> = vec![0i64; n * cols * rows];
     ///
     /// let mut buf: Vec<u8> = vec![u8::default(); module.vmp_prepare_tmp_bytes(rows, cols)];
     ///
     /// let mut vmp_pmat: VmpPMat = module.new_vmp_pmat(rows, cols);
-    /// module.vmp_prepare_contiguous(&mut vmp_pmat, &b_mat.data, &mut buf);
+    /// module.vmp_prepare_contiguous(&mut vmp_pmat, &b_mat, &mut buf);
     ///
     /// vmp_pmat.free() // don't forget to free the memory once vmp_pmat is not needed anymore.
     /// ```
     fn vmp_prepare_contiguous(&self, b: &mut VmpPMat, a: &[i64], buf: &mut [u8]);
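The contiguous layout `vmp_prepare_contiguous` expects is the same rows x cols x n arrangement the removed `Matrix3D` modeled; a small indexing sketch (hypothetical helper, mirroring the arithmetic of `Matrix3D::at`):

```rust
// Entry (row, col) of the flattened matrix is the n-coefficient window at
// row * (cols * n) + col * n.
fn entry(b_mat: &[i64], n: usize, cols: usize, row: usize, col: usize) -> &[i64] {
    let idx = row * (cols * n) + col * n;
    &b_mat[idx..idx + n]
}
```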
 
     /// Prepares a [VmpPMat] from a vector of [VecZnx].
     ///
@@ -146,7 +140,7 @@ pub trait VmpPMatOps {
     ///
     /// # Example
     /// ```
-    /// use base2k::{Module, FFT64, Matrix3D, VmpPMat, VmpPMatOps, VecZnx, VecZnxOps, Free};
+    /// use base2k::{Module, FFT64, VmpPMat, VmpPMatOps, VecZnx, VecZnxApi, VecZnxOps, Free};
     /// use std::cmp::min;
     ///
     /// let n: usize = 1024;
@@ -159,17 +153,15 @@ pub trait VmpPMatOps {
     ///     vecznx.push(module.new_vec_znx(cols));
     /// });
     ///
-    /// let dble: Vec<&[i64]> = vecznx.iter().map(|v| v.data.as_slice()).collect();
-    ///
     /// let mut buf: Vec<u8> = vec![u8::default(); module.vmp_prepare_tmp_bytes(rows, cols)];
     ///
     /// let mut vmp_pmat: VmpPMat = module.new_vmp_pmat(rows, cols);
-    /// module.vmp_prepare_dblptr(&mut vmp_pmat, &dble, &mut buf);
+    /// module.vmp_prepare_dblptr(&mut vmp_pmat, &vecznx, &mut buf);
     ///
     /// vmp_pmat.free();
     /// module.free();
     /// ```
-    fn vmp_prepare_dblptr(&self, b: &mut VmpPMat, a: &[&[i64]], buf: &mut [u8]);
+    fn vmp_prepare_dblptr<T: VecZnxApi + Infos>(&self, b: &mut VmpPMat, a: &Vec<T>, buf: &mut [u8]);
 
     /// Prepares the ith-row of [VmpPMat] from a vector of [VecZnx].
     ///
@@ -183,7 +175,7 @@ pub trait VmpPMatOps {
     /// The size of buf can be obtained with [VmpPMatOps::vmp_prepare_tmp_bytes].
     ///
     /// # Example
     /// ```
-    /// use base2k::{Module, FFT64, Matrix3D, VmpPMat, VmpPMatOps, VecZnx, VecZnxOps, Free};
+    /// use base2k::{Module, FFT64, VmpPMat, VmpPMatOps, VecZnx, VecZnxOps, Free};
     /// use std::cmp::min;
     ///
     /// let n: usize = 1024;
@@ -191,7 +183,7 @@ pub trait VmpPMatOps {
     /// let rows: usize = 5;
     /// let cols: usize = 6;
     ///
-    /// let vecznx = vec![0i64; cols*n];
+    /// let vecznx = module.new_vec_znx(cols);
     ///
     /// let mut buf: Vec<u8> = vec![u8::default(); module.vmp_prepare_tmp_bytes(rows, cols)];
     ///
@@ -201,7 +193,13 @@ pub trait VmpPMatOps {
     /// vmp_pmat.free();
     /// module.free();
     /// ```
-    fn vmp_prepare_row(&self, b: &mut VmpPMat, a: &[i64], row_i: usize, tmp_bytes: &mut [u8]);
+    fn vmp_prepare_row<T: VecZnxApi + Infos>(
+        &self,
+        b: &mut VmpPMat,
+        a: &T,
+        row_i: usize,
+        tmp_bytes: &mut [u8],
+    );
 
     /// Returns the size of the scratch space necessary for [VmpPMatOps::vmp_apply_dft].
     ///
@@ -246,7 +244,7 @@ pub trait VmpPMatOps {
     ///
     /// # Example
     /// ```
-    /// use base2k::{Module, VecZnx, VecZnxOps, VecZnxDft, VmpPMat, VmpPMatOps, FFT64, Free};
+    /// use base2k::{Module, VecZnx, VecZnxOps, VecZnxDft, VecZnxDftOps, VmpPMat, VmpPMatOps, FFT64, Free, VecZnxApi};
     ///
     /// let n = 1024;
     ///
@@ -270,7 +268,13 @@ pub trait VmpPMatOps {
     /// vmp_pmat.free();
     /// module.free();
     /// ```
-    fn vmp_apply_dft(&self, c: &mut VecZnxDft, a: &VecZnx, b: &VmpPMat, buf: &mut [u8]);
+    fn vmp_apply_dft<T: VecZnxApi + Infos>(
+        &self,
+        c: &mut VecZnxDft,
+        a: &T,
+        b: &VmpPMat,
+        buf: &mut [u8],
+    );
 
     /// Returns the size of the scratch space necessary for [VmpPMatOps::vmp_apply_dft_to_dft].
     ///
@@ -316,7 +320,7 @@ pub trait VmpPMatOps {
     ///
     /// # Example
     /// ```
-    /// use base2k::{Module, VecZnx, VecZnxDft, VmpPMat, VmpPMatOps, FFT64, Free};
+    /// use base2k::{Module, VecZnx, VecZnxDft, VecZnxDftOps, VmpPMat, VmpPMatOps, FFT64, Free};
     ///
     /// let n = 1024;
     ///
@@ -370,7 +374,7 @@ pub trait VmpPMatOps {
     ///
     /// # Example
     /// ```
-    /// use base2k::{Module, VecZnx, VecZnxOps, VecZnxDft, VmpPMat, VmpPMatOps, FFT64, Free};
+    /// use base2k::{Module, VecZnx, VecZnxOps, VecZnxDft, VmpPMat, VmpPMatOps, FFT64, Free, VecZnxApi, VecZnxDftOps};
     ///
     /// let n = 1024;
     ///
@@ -424,7 +428,7 @@ impl VmpPMatOps for Module {
         }
     }
 
-    fn vmp_prepare_dblptr(&self, b: &mut VmpPMat, a: &[&[i64]], buf: &mut [u8]) {
+    fn vmp_prepare_dblptr<T: VecZnxApi + Infos>(
+        &self,
+        b: &mut VmpPMat,
+        a: &Vec<T>,
+        buf: &mut [u8],
+    ) {
         let ptrs: Vec<*const i64> = a.iter().map(|v| v.as_ptr()).collect();
         unsafe {
             vmp::vmp_prepare_dblptr(
                 self.0,
@@ -438,7 +447,13 @@ impl VmpPMatOps for Module {
         }
     }
 
-    fn vmp_prepare_row(&self, b: &mut VmpPMat, a: &[i64], row_i: usize, buf: &mut [u8]) {
+    fn vmp_prepare_row<T: VecZnxApi + Infos>(
+        &self,
+        b: &mut VmpPMat,
+        a: &T,
+        row_i: usize,
+        buf: &mut [u8],
+    ) {
         unsafe {
             vmp::vmp_prepare_row(
                 self.0,
@@ -470,7 +485,13 @@ impl VmpPMatOps for Module {
         }
     }
 
-    fn vmp_apply_dft(&self, c: &mut VecZnxDft, a: &VecZnx, b: &VmpPMat, buf: &mut [u8]) {
+    fn vmp_apply_dft<T: VecZnxApi + Infos>(
+        &self,
+        c: &mut VecZnxDft,
+        a: &T,
+        b: &VmpPMat,
+        buf: &mut [u8],
+    ) {
         unsafe {
             vmp::vmp_apply_dft(
                 self.0,
@@ -537,135 +558,3 @@ impl VmpPMatOps for Module {
         }
     }
 }
-
-/// A helper struture that stores a 3D matrix as a contiguous array.
-/// To be passed to [VmpPMatOps::vmp_prepare_contiguous].
-///
-/// rows: index of the i-th base2K power.
-/// cols: index of the j-th limb of the i-th row.
-/// n   : polynomial degree.
-///
-/// A [Matrix3D] can be seen as a vector of [VecZnx].
-pub struct Matrix3D<T> {
-    pub data: Vec<T>,
-    pub rows: usize,
-    pub cols: usize,
-    pub n: usize,
-}
-
-impl<T: Default + Copy> Matrix3D<T> {
-    /// Allocates a new [Matrix3D] with the respective dimensions.
-    ///
-    /// # Arguments
-    ///
-    /// * `rows`: the number of rows of the matrix.
-    /// * `cols`: the number of cols of the matrix.
-    /// # `n`: the size of each entry of the matrix.
-    ///
-    /// # Example
-    /// ```
-    /// use base2k::Matrix3D;
-    ///
-    /// let rows = 5; // #decomp
-    /// let cols = 5; // #limbs
-    /// let n = 1024; // #coeffs
-    ///
-    /// let mut mat = Matrix3D::<i64>::new(rows, cols, n);
-    /// ```
-    pub fn new(rows: usize, cols: usize, n: usize) -> Self {
-        let size = rows * cols * n;
-        Self {
-            data: vec![T::default(); size],
-            rows,
-            cols,
-            n,
-        }
-    }
-
-    /// Returns a non-mutable reference to the entry (row, col) of the [Matrix3D].
-    /// The returned array is of size n.
-    ///
-    /// # Arguments
-    ///
-    /// * `row`: the index of the row.
-    /// * `col`: the index of the col.
-    ///
-    /// # Example
-    /// ```
-    /// use base2k::Matrix3D;
-    ///
-    /// let rows = 5; // #decomp
-    /// let cols = 5; // #limbs
-    /// let n = 1024; // #coeffs
-    ///
-    /// let mut mat = Matrix3D::<i64>::new(rows, cols, n);
-    ///
-    /// let elem: &[i64] = mat.at(4, 4); // size n
-    /// ```
-    pub fn at(&self, row: usize, col: usize) -> &[T] {
-        assert!(row < self.rows && col < self.cols);
-        let idx: usize = row * (self.n * self.cols) + col * self.n;
-        &self.data[idx..idx + self.n]
-    }
-
-    /// Returns a mutable reference of the array at the (row, col) entry of the [Matrix3D].
-    /// The returned array is of size n.
-    ///
-    /// # Arguments
-    ///
-    /// * `row`: the index of the row.
-    /// * `col`: the index of the col.
-    ///
-    /// # Example
-    /// ```
-    /// use base2k::Matrix3D;
-    ///
-    /// let rows = 5; // #decomp
-    /// let cols = 5; // #limbs
-    /// let n = 1024; // #coeffs
-    ///
-    /// let mut mat = Matrix3D::<i64>::new(rows, cols, n);
-    ///
-    /// let elem: &mut [i64] = mat.at_mut(4, 4); // size n
-    /// ```
-    pub fn at_mut(&mut self, row: usize, col: usize) -> &mut [T] {
-        assert!(row < self.rows && col < self.cols);
-        let idx: usize = row * (self.n * self.cols) + col * self.n;
-        &mut self.data[idx..idx + self.n]
-    }
-
-    /// Sets the entry \[row\] of the [Matrix3D].
-    /// Typicall this is used to assign a [VecZnx] to the i-th row
-    /// of the [Matrix3D].
-    ///
-    /// # Arguments
-    ///
-    /// * `row`: the index of the row.
-    /// * `a`: the data to encode onthe row.
-    ///
-    /// # Example
-    /// ```
-    /// use base2k::{Matrix3D, VecZnx};
-    ///
-    /// let rows = 5; // #decomp
-    /// let cols = 5; // #limbs
-    /// let n = 1024; // #coeffs
-    ///
-    /// let mut mat = Matrix3D::<i64>::new(rows, cols, n);
-    ///
-    /// let a: VecZnx = VecZnx::new(n, cols);
-    ///
-    /// mat.set_row(1, &a.data);
-    /// ```
-    pub fn set_row(&mut self, row: usize, a: &[T]) {
-        assert!(
-            row < self.rows,
-            "invalid argument row: row={} > self.rows={}",
-            row,
-            self.rows
-        );
-        let idx: usize = row * (self.n * self.cols);
-        let size: usize = min(a.len(), self.cols * self.n);
-        self.data[idx..idx + size].copy_from_slice(&a[..size]);
-    }
-}
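Migration sketch for downstream users of the removed `Matrix3D` (illustrative; `module`, `rows` and `cols` are assumptions, and `VecZnxApi` must be in scope for `at_mut`): one `VecZnx` per row is now the canonical input to `vmp_prepare_dblptr`.

```rust
let mut rows_vec: Vec<VecZnx> = (0..rows).map(|_| module.new_vec_znx(cols)).collect();
rows_vec[0].at_mut(0)[1] = 1; // was: b_mat.at_mut(0, 0)[1] = 1
let mut buf: Vec<u8> = vec![0u8; module.vmp_prepare_tmp_bytes(rows, cols)];
let mut pmat: VmpPMat = module.new_vmp_pmat(rows, cols);
module.vmp_prepare_dblptr(&mut pmat, &rows_vec, &mut buf);
pmat.free();
```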