updated base2k backend

This commit is contained in:
Jean-Philippe Bossuat
2025-02-14 10:58:28 +01:00
parent 4aeaf99fe2
commit 68e61dc0e3
7 changed files with 137 additions and 22 deletions

View File

@@ -40,8 +40,10 @@ fn main() {
vecznx[i].data[i * n + 1] = 1 as i64;
});
let dble: Vec<&[i64]> = vecznx.iter().map(|v| v.data.as_slice()).collect();
let mut vmp_pmat: VmpPMat = module.new_vmp_pmat(rows, cols);
module.vmp_prepare_dblptr(&mut vmp_pmat, &vecznx, &mut buf);
module.vmp_prepare_dblptr(&mut vmp_pmat, &dble, &mut buf);
let mut c_dft: VecZnxDft = module.new_vec_znx_dft(cols);
module.vmp_apply_dft(&mut c_dft, &a, &vmp_pmat, &mut buf);

View File

@@ -73,3 +73,51 @@ pub fn cast_u8_to_f64_slice(data: &mut [u8]) -> &[f64] {
let len: usize = data.len() / std::mem::size_of::<f64>();
unsafe { std::slice::from_raw_parts(ptr, len) }
}
use std::alloc::{alloc, Layout};
/// Allocates a `Vec<u8>` of length `size` whose backing storage is aligned to
/// `align` bytes and zero-initialized (matching the `vec![0u8; size]` it replaces).
///
/// # Arguments
///
/// * `size`: number of bytes to allocate; must be a multiple of `align`.
/// * `align`: required alignment; must be a non-zero power of two.
///
/// # Panics
/// If `align` is zero or not a power of two, if `size` is not a multiple of
/// `align`, or on allocation failure (via `handle_alloc_error`).
///
/// NOTE(review): the returned `Vec` will deallocate with `align_of::<u8>() == 1`,
/// not `align`, which technically violates the global-allocator contract even
/// though common allocators tolerate it — consider a dedicated wrapper type.
pub fn alloc_aligned_u8(size: usize, align: usize) -> Vec<u8> {
    // `is_power_of_two` also rejects 0, avoiding the `align - 1` underflow.
    assert!(
        align.is_power_of_two(),
        "align={} must be a non-zero power of two",
        align
    );
    assert_eq!(
        size % align,
        0,
        "size={} must be a multiple of align={}",
        size,
        align
    );
    // A zero-size call to `alloc` is undefined behavior; an empty Vec needs no
    // backing storage at all.
    if size == 0 {
        return Vec::new();
    }
    unsafe {
        let layout: Layout = Layout::from_size_align(size, align).expect("Invalid alignment");
        // Zeroed so the result is deterministic, like the `vec![0; n]` buffers
        // this helper replaces at its call sites.
        let ptr: *mut u8 = std::alloc::alloc_zeroed(layout);
        if ptr.is_null() {
            // Standard out-of-memory path; aborts with a proper diagnostic.
            std::alloc::handle_alloc_error(layout);
        }
        // SAFETY: `ptr` was allocated by the global allocator with exactly
        // `size` bytes; len == cap == size.
        Vec::from_raw_parts(ptr, size, size)
    }
}
/// Allocates a `Vec<T>` of length `size` whose backing storage is aligned to
/// `align` bytes and zero-filled.
///
/// # Arguments
///
/// * `size`: number of `T` elements; `size * size_of::<T>()` must be a multiple of `align`.
/// * `align`: required alignment; must be at least `align_of::<T>()`.
///
/// # Panics
/// If `T` is zero-sized, if `align < align_of::<T>()`, or if the byte size is
/// not a multiple of `align`.
///
/// NOTE(review): all-zero bytes must be a valid bit pattern for `T` (true for
/// the integer/float types used in this crate); the Vec also deallocates with
/// `align_of::<T>()` rather than `align` — see `alloc_aligned_u8`.
pub fn alloc_aligned<T>(size: usize, align: usize) -> Vec<T> {
    let t_size: usize = std::mem::size_of::<T>();
    // A ZST would divide by zero below and makes no sense for a raw buffer.
    assert!(t_size > 0, "zero-sized types are not supported");
    // The cast to *mut T below is only sound if the storage is at least
    // T-aligned.
    assert!(
        align >= std::mem::align_of::<T>(),
        "align={} must be >= align_of::<T>()={}",
        align,
        std::mem::align_of::<T>()
    );
    assert_eq!(
        (size * t_size) % align,
        0,
        "size={} must be a multiple of align={}",
        size,
        align
    );
    // Zero-size allocations are UB in the callee; return an empty Vec instead.
    if size == 0 {
        return Vec::new();
    }
    let mut vec_u8: Vec<u8> = alloc_aligned_u8(t_size * size, align);
    // Zero explicitly so the contents are deterministic regardless of how the
    // raw buffer was obtained (writing to uninitialized u8 storage is sound).
    unsafe { std::ptr::write_bytes(vec_u8.as_mut_ptr(), 0, vec_u8.len()) };
    let ptr: *mut T = vec_u8.as_mut_ptr() as *mut T;
    let len: usize = vec_u8.len() / t_size;
    let cap: usize = vec_u8.capacity() / t_size;
    // Ownership is transferred to the Vec<T> built below; prevent a double free.
    std::mem::forget(vec_u8);
    // SAFETY: ptr is T-aligned (asserted above), the byte length is an exact
    // multiple of size_of::<T>(), and the bytes are zero-initialized.
    unsafe { Vec::from_raw_parts(ptr, len, cap) }
}
/// Builds an owning `Vec<T>` that aliases the memory of a borrowed slice.
///
/// WARNING(review): this is deeply unsafe despite the safe signature. The
/// returned `Vec` believes it owns `slice`'s memory: if it is ever dropped,
/// reallocated (e.g. by `push`), or truncated, it will free or move memory it
/// does not own — a double free / use-after-free. It also creates a second
/// mutable alias of `slice` for as long as the Vec lives. Callers must
/// guarantee the Vec is never dropped nor resized (e.g. `mem::forget` it or
/// only ever overwrite it in place), and must not touch the original slice
/// while the Vec is live. Consider replacing this with a borrowed view type.
fn alias_mut_slice_to_vec<T>(slice: &mut [T]) -> Vec<T> {
    let ptr = slice.as_mut_ptr();
    let len = slice.len();
    // capacity == len, so any growth would reallocate (and later free) foreign memory.
    unsafe { Vec::from_raw_parts(ptr, len, len) }
}

View File

@@ -1,7 +1,7 @@
use crate::ffi::svp;
use crate::{Module, VecZnx, VecZnxDft};
use crate::{alias_mut_slice_to_vec, Module, VecZnx, VecZnxDft};
use crate::Infos;
use crate::{alloc_aligned, cast_mut_u8_to_mut_i64_slice, Infos};
use rand::seq::SliceRandom;
use rand_core::RngCore;
use rand_distr::{Distribution, WeightedIndex};
@@ -17,14 +17,18 @@ impl Module {
impl Scalar {
pub fn new(n: usize) -> Self {
Self(vec![i64::default(); Self::buffer_size(n)])
Self(alloc_aligned::<i64>(n, 64))
}
pub fn n(&self) -> usize {
self.0.len()
}
pub fn buffer_size(n: usize) -> usize {
n
}
pub fn from_buffer(&mut self, n: usize, buf: &[i64]) {
pub fn from_buffer(&mut self, n: usize, buf: &mut [u8]) {
let size: usize = Self::buffer_size(n);
assert!(
buf.len() >= size,
@@ -33,7 +37,7 @@ impl Scalar {
n,
size
);
self.0 = Vec::from(&buf[..size])
self.0 = alias_mut_slice_to_vec(cast_mut_u8_to_mut_i64_slice(&mut buf[..size]))
}
pub fn as_ptr(&self) -> *const i64 {
@@ -50,6 +54,7 @@ impl Scalar {
}
pub fn fill_ternary_hw(&mut self, hw: usize, source: &mut Source) {
assert!(hw <= self.n());
self.0[..hw]
.iter_mut()
.for_each(|x: &mut i64| *x = (((source.next_u32() & 1) as i64) << 1) - 1);

View File

@@ -1,6 +1,7 @@
use crate::cast_mut_u8_to_mut_i64_slice;
use crate::ffi::vec_znx;
use crate::ffi::znx;
use crate::{alias_mut_slice_to_vec, alloc_aligned};
use crate::{Infos, Module};
use itertools::izip;
use std::cmp::min;
@@ -21,7 +22,7 @@ impl VecZnx {
pub fn new(n: usize, limbs: usize) -> Self {
Self {
n: n,
data: vec![i64::default(); n * limbs],
data: alloc_aligned::<i64>(n * limbs, 64),
}
}
@@ -47,7 +48,7 @@ impl VecZnx {
VecZnx {
n: n,
data: Vec::from(cast_mut_u8_to_mut_i64_slice(&mut buf[..size])),
data: alias_mut_slice_to_vec(cast_mut_u8_to_mut_i64_slice(&mut buf[..size])),
}
}
@@ -106,7 +107,7 @@ impl VecZnx {
///
/// # Example
/// ```
/// use base2k::{VecZnx, Encoding, Infos};
/// use base2k::{VecZnx, Encoding, Infos, alloc_aligned};
/// use itertools::izip;
/// use sampling::source::Source;
///
@@ -115,8 +116,8 @@ impl VecZnx {
/// let limbs: usize = 5; // number of limbs (i.e. can store coeffs in the range +/- 2^{limbs * log_base2k - 1})
/// let log_k: usize = limbs * log_base2k - 5;
/// let mut a: VecZnx = VecZnx::new(n, limbs);
/// let mut carry: Vec<u8> = vec![u8::default(); a.n()<<3];
/// let mut have: Vec<i64> = vec![i64::default(); a.n()];
/// let mut carry: Vec<u8> = alloc_aligned::<u8>(a.n()<<3, 64);
/// let mut have: Vec<i64> = alloc_aligned::<i64>(a.n(), 64);
/// let mut source = Source::new([1; 32]);
///
/// // Populates the first limb of the of polynomials with random i64 values.
@@ -135,7 +136,7 @@ impl VecZnx {
/// .for_each(|x| assert!(x.abs() <= base_half, "|x|={} > 2^(k-1)={}", x, base_half));
///
/// // Ensures reconstructed normalized values are equal to non-normalized values.
/// let mut want = vec![i64::default(); n];
/// let mut want = alloc_aligned::<i64>(a.n(), 64);
/// a.decode_vec_i64(log_base2k, log_k, &mut want);
/// izip!(want, have).for_each(|(a, b)| assert_eq!(a, b, "{} != {}", a, b));
/// ```

View File

@@ -29,6 +29,25 @@ impl Module {
unsafe { VecZnxBig(vec_znx_big::new_vec_znx_big(self.0, limbs as u64), limbs) }
}
/// Returns a new [VecZnxBig] with the provided bytes array as backing array.
///
/// # Arguments
///
/// * `limbs`: the number of limbs of the [VecZnxBig].
/// * `bytes`: a byte array of size at least [Module::bytes_of_vec_znx_big].
///
/// # Panics
/// If `bytes.len()` < [Module::bytes_of_vec_znx_big].
pub fn new_vec_znx_big_from_bytes(&self, limbs: usize, bytes: &mut [u8]) -> VecZnxBig {
    // Message previously named bytes_of_vec_znx_dft (copy-paste from the DFT
    // variant); it now reports the quantity actually checked.
    assert!(
        bytes.len() >= self.bytes_of_vec_znx_big(limbs),
        "invalid bytes: bytes.len()={} < bytes_of_vec_znx_big={}",
        bytes.len(),
        self.bytes_of_vec_znx_big(limbs)
    );
    VecZnxBig::from_bytes(limbs, bytes)
}
/// Returns the minimum number of bytes necessary to allocate
/// a new [VecZnxBig] through [VecZnxBig::from_bytes].
pub fn bytes_of_vec_znx_big(&self, limbs: usize) -> usize {
@@ -131,6 +150,42 @@ impl Module {
}
}
/// Returns the minimum number of scratch bytes required by
/// [Module::vec_znx_big_range_normalize_base2k], as reported by the backend.
pub fn vec_znx_big_range_normalize_base2k_tmp_bytes(&self) -> usize {
    // self.0 is presumably the FFI module handle — thin passthrough to the C API.
    unsafe { vec_znx_big::vec_znx_big_range_normalize_base2k_tmp_bytes(self.0) as usize }
}
/// Normalizes (base `2^log_base2k`) the limbs of `a` selected by the range
/// `a_range_begin..a_range_xend` with step `a_range_step`, writing into `res`.
///
/// # Arguments
///
/// * `log_base2k`: base-2 logarithm of the normalization base.
/// * `res`: output [VecZnx].
/// * `a`: input [VecZnxBig].
/// * `a_range_begin`: first limb of the range (inclusive).
/// * `a_range_xend`: end of the range (exclusive, per the `x` prefix).
/// * `a_range_step`: stride over the limbs of the range.
/// * `tmp_bytes`: scratch space of size at least
///   [Module::vec_znx_big_range_normalize_base2k_tmp_bytes].
///
/// # Panics
/// If `tmp_bytes` is smaller than required.
pub fn vec_znx_big_range_normalize_base2k(
    &self,
    log_base2k: usize,
    res: &mut VecZnx,
    a: &VecZnxBig,
    a_range_begin: usize,
    a_range_xend: usize,
    a_range_step: usize,
    tmp_bytes: &mut [u8],
) {
    // Message previously printed `<=`; the failing relation is `<`, matching
    // the wording of the other size asserts in this crate.
    assert!(
        tmp_bytes.len() >= self.vec_znx_big_range_normalize_base2k_tmp_bytes(),
        "invalid tmp_bytes: tmp_bytes.len()={} < self.vec_znx_big_range_normalize_base2k_tmp_bytes()={}",
        tmp_bytes.len(),
        self.vec_znx_big_range_normalize_base2k_tmp_bytes()
    );
    unsafe {
        vec_znx_big::vec_znx_big_range_normalize_base2k(
            self.0,
            log_base2k as u64,
            res.as_mut_ptr(),
            res.limbs() as u64,
            res.n() as u64,
            a.0,
            a_range_begin as u64,
            a_range_xend as u64,
            a_range_step as u64,
            tmp_bytes.as_mut_ptr(),
        );
    }
}
pub fn vec_znx_big_automorphism(&self, gal_el: i64, b: &mut VecZnxBig, a: &VecZnxBig) {
unsafe {
vec_znx_big::vec_znx_big_automorphism(

View File

@@ -39,7 +39,7 @@ impl Module {
///
/// # Panics
/// If `bytes.len()` < [Module::bytes_of_vec_znx_dft].
pub fn new_vec_znx_from_bytes(&self, limbs: usize, bytes: &mut [u8]) -> VecZnxDft {
pub fn new_vec_znx_dft_from_bytes(&self, limbs: usize, bytes: &mut [u8]) -> VecZnxDft {
assert!(
bytes.len() >= self.bytes_of_vec_znx_dft(limbs),
"invalid bytes: bytes.len()={} < bytes_of_vec_znx_dft={}",
@@ -63,7 +63,9 @@ impl Module {
b.limbs(),
a_limbs
);
unsafe { vec_znx_dft::vec_znx_idft_tmp_a(self.0, b.0, b.limbs() as u64, a.0, a_limbs as u64) }
unsafe {
vec_znx_dft::vec_znx_idft_tmp_a(self.0, b.0, b.limbs() as u64, a.0, a_limbs as u64)
}
}
// Returns the size of the scratch space for [vec_znx_idft].

View File

@@ -159,15 +159,17 @@ pub trait VmpPMatOps {
/// vecznx.push(module.new_vec_znx(cols));
/// });
///
/// let dble: Vec<&[i64]> = vecznx.iter().map(|v| v.data.as_slice()).collect();
///
/// let mut buf: Vec<u8> = vec![u8::default(); module.vmp_prepare_tmp_bytes(rows, cols)];
///
/// let mut vmp_pmat: VmpPMat = module.new_vmp_pmat(rows, cols);
/// module.vmp_prepare_dblptr(&mut vmp_pmat, &vecznx, &mut buf);
/// module.vmp_prepare_dblptr(&mut vmp_pmat, &dble, &mut buf);
///
/// vmp_pmat.free();
/// module.free();
/// ```
fn vmp_prepare_dblptr(&self, b: &mut VmpPMat, a: &Vec<VecZnx>, buf: &mut [u8]);
fn vmp_prepare_dblptr(&self, b: &mut VmpPMat, a: &[&[i64]], buf: &mut [u8]);
/// Prepares the ith-row of [VmpPMat] from a vector of [VecZnx].
///
@@ -189,7 +191,7 @@ pub trait VmpPMatOps {
/// let rows: usize = 5;
/// let cols: usize = 6;
///
/// let vecznx = module.new_vec_znx(cols);
/// let vecznx = vec![0i64; cols*n];
///
/// let mut buf: Vec<u8> = vec![u8::default(); module.vmp_prepare_tmp_bytes(rows, cols)];
///
@@ -199,7 +201,7 @@ pub trait VmpPMatOps {
/// vmp_pmat.free();
/// module.free();
/// ```
fn vmp_prepare_row(&self, b: &mut VmpPMat, a: &VecZnx, row_i: usize, tmp_bytes: &mut [u8]);
fn vmp_prepare_row(&self, b: &mut VmpPMat, a: &[i64], row_i: usize, tmp_bytes: &mut [u8]);
/// Returns the size of the stratch space necessary for [VmpPMatOps::vmp_apply_dft].
///
@@ -422,8 +424,8 @@ impl VmpPMatOps for Module {
}
}
fn vmp_prepare_dblptr(&self, b: &mut VmpPMat, a: &Vec<VecZnx>, buf: &mut [u8]) {
let ptrs: Vec<*const i64> = a.iter().map(|v| v.data.as_ptr()).collect();
fn vmp_prepare_dblptr(&self, b: &mut VmpPMat, a: &[&[i64]], buf: &mut [u8]) {
let ptrs: Vec<*const i64> = a.iter().map(|v| v.as_ptr()).collect();
unsafe {
vmp::vmp_prepare_dblptr(
self.0,
@@ -436,12 +438,12 @@ impl VmpPMatOps for Module {
}
}
fn vmp_prepare_row(&self, b: &mut VmpPMat, a: &VecZnx, row_i: usize, buf: &mut [u8]) {
fn vmp_prepare_row(&self, b: &mut VmpPMat, a: &[i64], row_i: usize, buf: &mut [u8]) {
unsafe {
vmp::vmp_prepare_row(
self.0,
b.data(),
a.data.as_ptr(),
a.as_ptr(),
row_i as u64,
b.rows() as u64,
b.cols() as u64,