From 68e61dc0e36a58a12d3671fb4e4c9f534c24f6d4 Mon Sep 17 00:00:00 2001
From: Jean-Philippe Bossuat
Date: Fri, 14 Feb 2025 10:58:28 +0100
Subject: [PATCH] updated base2k backend

---
 base2k/examples/vector_matrix_product.rs |  4 +-
 base2k/src/lib.rs                        | 48 +++++++++++++++++++++
 base2k/src/svp.rs                        | 15 ++++---
 base2k/src/vec_znx.rs                    | 13 +++---
 base2k/src/vec_znx_big.rs                | 55 ++++++++++++++++++++++++
 base2k/src/vec_znx_dft.rs                |  6 ++-
 base2k/src/vmp.rs                        | 18 ++++----
 7 files changed, 137 insertions(+), 22 deletions(-)

diff --git a/base2k/examples/vector_matrix_product.rs b/base2k/examples/vector_matrix_product.rs
index fb10d3d..e4e1e29 100644
--- a/base2k/examples/vector_matrix_product.rs
+++ b/base2k/examples/vector_matrix_product.rs
@@ -40,8 +40,10 @@ fn main() {
         vecznx[i].data[i * n + 1] = 1 as i64;
     });
 
+    let dble: Vec<&[i64]> = vecznx.iter().map(|v| v.data.as_slice()).collect();
+
     let mut vmp_pmat: VmpPMat = module.new_vmp_pmat(rows, cols);
-    module.vmp_prepare_dblptr(&mut vmp_pmat, &vecznx, &mut buf);
+    module.vmp_prepare_dblptr(&mut vmp_pmat, &dble, &mut buf);
 
     let mut c_dft: VecZnxDft = module.new_vec_znx_dft(cols);
     module.vmp_apply_dft(&mut c_dft, &a, &vmp_pmat, &mut buf);
diff --git a/base2k/src/lib.rs b/base2k/src/lib.rs
index df2d6f3..769858b 100644
--- a/base2k/src/lib.rs
+++ b/base2k/src/lib.rs
@@ -73,3 +73,51 @@ pub fn cast_u8_to_f64_slice(data: &mut [u8]) -> &[f64] {
     let len: usize = data.len() / std::mem::size_of::<f64>();
     unsafe { std::slice::from_raw_parts(ptr, len) }
 }
+
+use std::alloc::{alloc, Layout};
+
+pub fn alloc_aligned_u8(size: usize, align: usize) -> Vec<u8> {
+    assert_eq!(
+        align & (align - 1),
+        0,
+        "align={} must be a power of two",
+        align
+    );
+    assert_eq!(
+        (size * std::mem::size_of::<u8>()) % align,
+        0,
+        "size={} must be a multiple of align={}",
+        size,
+        align
+    );
+    unsafe {
+        let layout: Layout = Layout::from_size_align(size, align).expect("Invalid alignment");
+        let ptr: *mut u8 = alloc(layout);
+        if ptr.is_null() {
+            panic!("Memory allocation failed");
+        }
+        Vec::from_raw_parts(ptr, size, size)
+    }
+}
+
+pub fn alloc_aligned<T>(size: usize, align: usize) -> Vec<T> {
+    assert_eq!(
+        (size * std::mem::size_of::<T>()) % align,
+        0,
+        "size={} must be a multiple of align={}",
+        size,
+        align
+    );
+    let mut vec_u8: Vec<u8> = alloc_aligned_u8(std::mem::size_of::<T>() * size, align);
+    let ptr: *mut T = vec_u8.as_mut_ptr() as *mut T;
+    let len: usize = vec_u8.len() / std::mem::size_of::<T>();
+    let cap: usize = vec_u8.capacity() / std::mem::size_of::<T>();
+    std::mem::forget(vec_u8);
+    unsafe { Vec::from_raw_parts(ptr, len, cap) }
+}
+
+fn alias_mut_slice_to_vec<T>(slice: &mut [T]) -> Vec<T> {
+    let ptr = slice.as_mut_ptr();
+    let len = slice.len();
+    unsafe { Vec::from_raw_parts(ptr, len, len) }
+}
diff --git a/base2k/src/svp.rs b/base2k/src/svp.rs
index 5ee2300..ee5cc3b 100644
--- a/base2k/src/svp.rs
+++ b/base2k/src/svp.rs
@@ -1,7 +1,7 @@
 use crate::ffi::svp;
-use crate::{Module, VecZnx, VecZnxDft};
+use crate::{alias_mut_slice_to_vec, Module, VecZnx, VecZnxDft};
 
-use crate::Infos;
+use crate::{alloc_aligned, cast_mut_u8_to_mut_i64_slice, Infos};
 use rand::seq::SliceRandom;
 use rand_core::RngCore;
 use rand_distr::{Distribution, WeightedIndex};
@@ -17,14 +17,18 @@ impl Module {
 impl Scalar {
     pub fn new(n: usize) -> Self {
-        Self(vec![i64::default(); Self::buffer_size(n)])
+        Self(alloc_aligned::<i64>(n, 64))
+    }
+
+    pub fn n(&self) -> usize {
+        self.0.len()
     }
 
     pub fn buffer_size(n: usize) -> usize {
         n
     }
 
-    pub fn from_buffer(&mut self, n: usize, buf: &[i64]) {
+    pub fn from_buffer(&mut self, n: usize, buf: &mut [u8]) {
         let size: usize = Self::buffer_size(n);
         assert!(
             buf.len() >= size,
@@ -33,7 +37,7 @@ impl Scalar {
             n,
             size
         );
-        self.0 = Vec::from(&buf[..size])
+        self.0 = alias_mut_slice_to_vec(cast_mut_u8_to_mut_i64_slice(&mut buf[..size]))
     }
 
     pub fn as_ptr(&self) -> *const i64 {
@@ -50,6 +54,7 @@ impl Scalar {
     }
 
     pub fn fill_ternary_hw(&mut self, hw: usize, source: &mut Source) {
+        assert!(hw <= self.n());
         self.0[..hw]
             .iter_mut()
             .for_each(|x: &mut i64| *x = (((source.next_u32() & 1) as i64) << 1) - 1);
diff --git a/base2k/src/vec_znx.rs b/base2k/src/vec_znx.rs
index f0a9421..615cb35 100644
--- a/base2k/src/vec_znx.rs
+++ b/base2k/src/vec_znx.rs
@@ -1,6 +1,7 @@
 use crate::cast_mut_u8_to_mut_i64_slice;
 use crate::ffi::vec_znx;
 use crate::ffi::znx;
+use crate::{alias_mut_slice_to_vec, alloc_aligned};
 use crate::{Infos, Module};
 use itertools::izip;
 use std::cmp::min;
@@ -21,7 +22,7 @@ impl VecZnx {
     pub fn new(n: usize, limbs: usize) -> Self {
         Self {
             n: n,
-            data: vec![i64::default(); n * limbs],
+            data: alloc_aligned::<i64>(n * limbs, 64),
         }
     }
 
@@ -47,7 +48,7 @@ impl VecZnx {
 
         VecZnx {
             n: n,
-            data: Vec::from(cast_mut_u8_to_mut_i64_slice(&mut buf[..size])),
+            data: alias_mut_slice_to_vec(cast_mut_u8_to_mut_i64_slice(&mut buf[..size])),
         }
     }
 
@@ -106,7 +107,7 @@ impl VecZnx {
     ///
     /// # Example
     /// ```
-    /// use base2k::{VecZnx, Encoding, Infos};
+    /// use base2k::{VecZnx, Encoding, Infos, alloc_aligned};
     /// use itertools::izip;
     /// use sampling::source::Source;
     ///
@@ -115,8 +116,8 @@ impl VecZnx {
     /// let limbs: usize = 5; // number of limbs (i.e. can store coeffs in the range +/- 2^{limbs * log_base2k - 1})
     /// let log_k: usize = limbs * log_base2k - 5;
     /// let mut a: VecZnx = VecZnx::new(n, limbs);
-    /// let mut carry: Vec<u8> = vec![u8::default(); a.n()<<3];
-    /// let mut have: Vec<i64> = vec![i64::default(); a.n()];
+    /// let mut carry: Vec<u8> = alloc_aligned::<u8>(a.n()<<3, 64);
+    /// let mut have: Vec<i64> = alloc_aligned::<i64>(a.n(), 64);
     /// let mut source = Source::new([1; 32]);
     ///
     /// // Populates the first limb of the of polynomials with random i64 values.
@@ -135,7 +136,7 @@ impl VecZnx {
     ///     .for_each(|x| assert!(x.abs() <= base_half, "|x|={} > 2^(k-1)={}", x, base_half));
     ///
     /// // Ensures reconstructed normalized values are equal to non-normalized values.
-    /// let mut want = vec![i64::default(); n];
+    /// let mut want = alloc_aligned::<i64>(a.n(), 64);
     /// a.decode_vec_i64(log_base2k, log_k, &mut want);
     /// izip!(want, have).for_each(|(a, b)| assert_eq!(a, b, "{} != {}", a, b));
     /// ```
diff --git a/base2k/src/vec_znx_big.rs b/base2k/src/vec_znx_big.rs
index 2e16d6a..83ed428 100644
--- a/base2k/src/vec_znx_big.rs
+++ b/base2k/src/vec_znx_big.rs
@@ -29,6 +29,25 @@ impl Module {
         unsafe { VecZnxBig(vec_znx_big::new_vec_znx_big(self.0, limbs as u64), limbs) }
     }
 
+    /// Returns a new [VecZnxBig] with the provided bytes array as backing array.
+    ///
+    /// # Arguments
+    ///
+    /// * `limbs`: the number of limbs of the [VecZnxBig].
+    /// * `bytes`: a byte array of size at least [Module::bytes_of_vec_znx_big].
+    ///
+    /// # Panics
+    /// If `bytes.len()` < [Module::bytes_of_vec_znx_big].
+    pub fn new_vec_znx_big_from_bytes(&self, limbs: usize, bytes: &mut [u8]) -> VecZnxBig {
+        assert!(
+            bytes.len() >= self.bytes_of_vec_znx_big(limbs),
+            "invalid bytes: bytes.len()={} < bytes_of_vec_znx_big={}",
+            bytes.len(),
+            self.bytes_of_vec_znx_big(limbs)
+        );
+        VecZnxBig::from_bytes(limbs, bytes)
+    }
+
     /// Returns the minimum number of bytes necessary to allocate
     /// a new [VecZnxBig] through [VecZnxBig::from_bytes].
     pub fn bytes_of_vec_znx_big(&self, limbs: usize) -> usize {
@@ -131,6 +150,42 @@ impl Module {
         }
     }
 
+    pub fn vec_znx_big_range_normalize_base2k_tmp_bytes(&self) -> usize {
+        unsafe { vec_znx_big::vec_znx_big_range_normalize_base2k_tmp_bytes(self.0) as usize }
+    }
+
+    pub fn vec_znx_big_range_normalize_base2k(
+        &self,
+        log_base2k: usize,
+        res: &mut VecZnx,
+        a: &VecZnxBig,
+        a_range_begin: usize,
+        a_range_xend: usize,
+        a_range_step: usize,
+        tmp_bytes: &mut [u8],
+    ) {
+        assert!(
+            tmp_bytes.len() >= self.vec_znx_big_range_normalize_base2k_tmp_bytes(),
+            "invalid tmp_bytes: tmp_bytes.len()={} < self.vec_znx_big_range_normalize_base2k_tmp_bytes()={}",
+            tmp_bytes.len(),
+            self.vec_znx_big_range_normalize_base2k_tmp_bytes()
+        );
+        unsafe {
+            vec_znx_big::vec_znx_big_range_normalize_base2k(
+                self.0,
+                log_base2k as u64,
+                res.as_mut_ptr(),
+                res.limbs() as u64,
+                res.n() as u64,
+                a.0,
+                a_range_begin as u64,
+                a_range_xend as u64,
+                a_range_step as u64,
+                tmp_bytes.as_mut_ptr(),
+            );
+        }
+    }
+
     pub fn vec_znx_big_automorphism(&self, gal_el: i64, b: &mut VecZnxBig, a: &VecZnxBig) {
         unsafe {
             vec_znx_big::vec_znx_big_automorphism(
diff --git a/base2k/src/vec_znx_dft.rs b/base2k/src/vec_znx_dft.rs
index 85cf7c7..38db9e5 100644
--- a/base2k/src/vec_znx_dft.rs
+++ b/base2k/src/vec_znx_dft.rs
@@ -39,7 +39,7 @@ impl Module {
     ///
     /// # Panics
     /// If `bytes.len()` < [Module::bytes_of_vec_znx_dft].
-    pub fn new_vec_znx_from_bytes(&self, limbs: usize, bytes: &mut [u8]) -> VecZnxDft {
+    pub fn new_vec_znx_dft_from_bytes(&self, limbs: usize, bytes: &mut [u8]) -> VecZnxDft {
         assert!(
             bytes.len() >= self.bytes_of_vec_znx_dft(limbs),
             "invalid bytes: bytes.len()={} < bytes_of_vec_znx_dft={}",
@@ -63,7 +63,9 @@ impl Module {
             b.limbs(),
             a_limbs
         );
-        unsafe { vec_znx_dft::vec_znx_idft_tmp_a(self.0, b.0, b.limbs() as u64, a.0, a_limbs as u64) }
+        unsafe {
+            vec_znx_dft::vec_znx_idft_tmp_a(self.0, b.0, b.limbs() as u64, a.0, a_limbs as u64)
+        }
     }
 
     // Returns the size of the scratch space for [vec_znx_idft].
diff --git a/base2k/src/vmp.rs b/base2k/src/vmp.rs
index 57947ab..08b6e85 100644
--- a/base2k/src/vmp.rs
+++ b/base2k/src/vmp.rs
@@ -159,15 +159,17 @@ pub trait VmpPMatOps {
     ///     vecznx.push(module.new_vec_znx(cols));
     /// });
     ///
+    /// let dble: Vec<&[i64]> = vecznx.iter().map(|v| v.data.as_slice()).collect();
+    ///
     /// let mut buf: Vec<u8> = vec![u8::default(); module.vmp_prepare_tmp_bytes(rows, cols)];
     ///
     /// let mut vmp_pmat: VmpPMat = module.new_vmp_pmat(rows, cols);
-    /// module.vmp_prepare_dblptr(&mut vmp_pmat, &vecznx, &mut buf);
+    /// module.vmp_prepare_dblptr(&mut vmp_pmat, &dble, &mut buf);
     ///
     /// vmp_pmat.free();
     /// module.free();
     /// ```
-    fn vmp_prepare_dblptr(&self, b: &mut VmpPMat, a: &Vec<VecZnx>, buf: &mut [u8]);
+    fn vmp_prepare_dblptr(&self, b: &mut VmpPMat, a: &[&[i64]], buf: &mut [u8]);
 
     /// Prepares the ith-row of [VmpPMat] from a vector of [VecZnx].
     ///
@@ -189,7 +191,7 @@ pub trait VmpPMatOps {
     /// let rows: usize = 5;
     /// let cols: usize = 6;
     ///
-    /// let vecznx = module.new_vec_znx(cols);
+    /// let vecznx = vec![0i64; cols*n];
     ///
     /// let mut buf: Vec<u8> = vec![u8::default(); module.vmp_prepare_tmp_bytes(rows, cols)];
     ///
     /// let mut vmp_pmat: VmpPMat = module.new_vmp_pmat(rows, cols);
     ///
@@ -199,7 +201,7 @@ pub trait VmpPMatOps {
     /// vmp_pmat.free();
     /// module.free();
     /// ```
-    fn vmp_prepare_row(&self, b: &mut VmpPMat, a: &VecZnx, row_i: usize, tmp_bytes: &mut [u8]);
+    fn vmp_prepare_row(&self, b: &mut VmpPMat, a: &[i64], row_i: usize, tmp_bytes: &mut [u8]);
 
     /// Returns the size of the stratch space necessary for [VmpPMatOps::vmp_apply_dft].
     ///
@@ -422,8 +424,8 @@
         }
     }
 
-    fn vmp_prepare_dblptr(&self, b: &mut VmpPMat, a: &Vec<VecZnx>, buf: &mut [u8]) {
-        let ptrs: Vec<*const i64> = a.iter().map(|v| v.data.as_ptr()).collect();
+    fn vmp_prepare_dblptr(&self, b: &mut VmpPMat, a: &[&[i64]], buf: &mut [u8]) {
+        let ptrs: Vec<*const i64> = a.iter().map(|v| v.as_ptr()).collect();
         unsafe {
             vmp::vmp_prepare_dblptr(
                 self.0,
@@ -436,12 +438,12 @@
         }
     }
 
-    fn vmp_prepare_row(&self, b: &mut VmpPMat, a: &VecZnx, row_i: usize, buf: &mut [u8]) {
+    fn vmp_prepare_row(&self, b: &mut VmpPMat, a: &[i64], row_i: usize, buf: &mut [u8]) {
         unsafe {
             vmp::vmp_prepare_row(
                 self.0,
                 b.data(),
-                a.data.as_ptr(),
+                a.as_ptr(),
                 row_i as u64,
                 b.rows() as u64,
                 b.cols() as u64,
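Usage note for reviewers (not part of the diff above): the sketch below shows how calling code migrates to the new `&[&[i64]]` form of `vmp_prepare_dblptr`, mirroring the change made in `base2k/examples/vector_matrix_product.rs`. It is a minimal sketch under assumptions: the helper name `prepare_pmat` and the `use base2k::{...}` import path are illustrative, not taken from this patch; only `new_vmp_pmat`, `vmp_prepare_tmp_bytes`, `vmp_prepare_dblptr`, and the public `data` field of `VecZnx` appear in the diff itself.

```
use base2k::{Module, VecZnx, VmpPMat, VmpPMatOps};

// Hypothetical helper: builds a prepared matrix from a set of row polynomials,
// each with `cols` limbs, using the new double-pointer entry point.
fn prepare_pmat(module: &Module, cols: usize, rows_data: &[VecZnx]) -> VmpPMat {
    let rows: usize = rows_data.len();

    // Borrow every row as a raw &[i64] slice; this is the new `dble` input.
    let dble: Vec<&[i64]> = rows_data.iter().map(|v| v.data.as_slice()).collect();

    // Scratch space sized by the library, as in the vmp.rs doc examples.
    let mut buf: Vec<u8> = vec![u8::default(); module.vmp_prepare_tmp_bytes(rows, cols)];

    let mut vmp_pmat: VmpPMat = module.new_vmp_pmat(rows, cols);
    module.vmp_prepare_dblptr(&mut vmp_pmat, &dble, &mut buf);
    vmp_pmat
}
```

Passing rows as plain `&[i64]` slices decouples the prepare step from the `VecZnx` type, so any suitably sized i64 buffer can feed a row; this matches the new `vmp_prepare_row(&self, b, a: &[i64], ...)` signature in the same file.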