mirror of
https://github.com/arnaucube/poulpy.git
synced 2026-02-10 05:06:44 +01:00
updated base2k backend
This commit is contained in:
@@ -40,8 +40,10 @@ fn main() {
|
||||
vecznx[i].data[i * n + 1] = 1 as i64;
|
||||
});
|
||||
|
||||
let dble: Vec<&[i64]> = vecznx.iter().map(|v| v.data.as_slice()).collect();
|
||||
|
||||
let mut vmp_pmat: VmpPMat = module.new_vmp_pmat(rows, cols);
|
||||
module.vmp_prepare_dblptr(&mut vmp_pmat, &vecznx, &mut buf);
|
||||
module.vmp_prepare_dblptr(&mut vmp_pmat, &dble, &mut buf);
|
||||
|
||||
let mut c_dft: VecZnxDft = module.new_vec_znx_dft(cols);
|
||||
module.vmp_apply_dft(&mut c_dft, &a, &vmp_pmat, &mut buf);
|
||||
|
||||
@@ -73,3 +73,51 @@ pub fn cast_u8_to_f64_slice(data: &mut [u8]) -> &[f64] {
|
||||
let len: usize = data.len() / std::mem::size_of::<f64>();
|
||||
unsafe { std::slice::from_raw_parts(ptr, len) }
|
||||
}
|
||||
|
||||
use std::alloc::{alloc, Layout};
|
||||
|
||||
pub fn alloc_aligned_u8(size: usize, align: usize) -> Vec<u8> {
|
||||
assert_eq!(
|
||||
align & (align - 1),
|
||||
0,
|
||||
"align={} must be a power of two",
|
||||
align
|
||||
);
|
||||
assert_eq!(
|
||||
(size * std::mem::size_of::<u8>()) % align,
|
||||
0,
|
||||
"size={} must be a multiple of align={}",
|
||||
size,
|
||||
align
|
||||
);
|
||||
unsafe {
|
||||
let layout: Layout = Layout::from_size_align(size, align).expect("Invalid alignment");
|
||||
let ptr: *mut u8 = alloc(layout);
|
||||
if ptr.is_null() {
|
||||
panic!("Memory allocation failed");
|
||||
}
|
||||
Vec::from_raw_parts(ptr, size, size)
|
||||
}
|
||||
}
|
||||
|
||||
pub fn alloc_aligned<T>(size: usize, align: usize) -> Vec<T> {
|
||||
assert_eq!(
|
||||
(size * std::mem::size_of::<T>()) % align,
|
||||
0,
|
||||
"size={} must be a multiple of align={}",
|
||||
size,
|
||||
align
|
||||
);
|
||||
let mut vec_u8: Vec<u8> = alloc_aligned_u8(std::mem::size_of::<T>() * size, align);
|
||||
let ptr: *mut T = vec_u8.as_mut_ptr() as *mut T;
|
||||
let len: usize = vec_u8.len() / std::mem::size_of::<T>();
|
||||
let cap: usize = vec_u8.capacity() / std::mem::size_of::<T>();
|
||||
std::mem::forget(vec_u8);
|
||||
unsafe { Vec::from_raw_parts(ptr, len, cap) }
|
||||
}
|
||||
|
||||
fn alias_mut_slice_to_vec<T>(slice: &mut [T]) -> Vec<T> {
|
||||
let ptr = slice.as_mut_ptr();
|
||||
let len = slice.len();
|
||||
unsafe { Vec::from_raw_parts(ptr, len, len) }
|
||||
}
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
use crate::ffi::svp;
|
||||
use crate::{Module, VecZnx, VecZnxDft};
|
||||
use crate::{alias_mut_slice_to_vec, Module, VecZnx, VecZnxDft};
|
||||
|
||||
use crate::Infos;
|
||||
use crate::{alloc_aligned, cast_mut_u8_to_mut_i64_slice, Infos};
|
||||
use rand::seq::SliceRandom;
|
||||
use rand_core::RngCore;
|
||||
use rand_distr::{Distribution, WeightedIndex};
|
||||
@@ -17,14 +17,18 @@ impl Module {
|
||||
|
||||
impl Scalar {
|
||||
pub fn new(n: usize) -> Self {
|
||||
Self(vec![i64::default(); Self::buffer_size(n)])
|
||||
Self(alloc_aligned::<i64>(n, 64))
|
||||
}
|
||||
|
||||
pub fn n(&self) -> usize {
|
||||
self.0.len()
|
||||
}
|
||||
|
||||
pub fn buffer_size(n: usize) -> usize {
|
||||
n
|
||||
}
|
||||
|
||||
pub fn from_buffer(&mut self, n: usize, buf: &[i64]) {
|
||||
pub fn from_buffer(&mut self, n: usize, buf: &mut [u8]) {
|
||||
let size: usize = Self::buffer_size(n);
|
||||
assert!(
|
||||
buf.len() >= size,
|
||||
@@ -33,7 +37,7 @@ impl Scalar {
|
||||
n,
|
||||
size
|
||||
);
|
||||
self.0 = Vec::from(&buf[..size])
|
||||
self.0 = alias_mut_slice_to_vec(cast_mut_u8_to_mut_i64_slice(&mut buf[..size]))
|
||||
}
|
||||
|
||||
pub fn as_ptr(&self) -> *const i64 {
|
||||
@@ -50,6 +54,7 @@ impl Scalar {
|
||||
}
|
||||
|
||||
pub fn fill_ternary_hw(&mut self, hw: usize, source: &mut Source) {
|
||||
assert!(hw <= self.n());
|
||||
self.0[..hw]
|
||||
.iter_mut()
|
||||
.for_each(|x: &mut i64| *x = (((source.next_u32() & 1) as i64) << 1) - 1);
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
use crate::cast_mut_u8_to_mut_i64_slice;
|
||||
use crate::ffi::vec_znx;
|
||||
use crate::ffi::znx;
|
||||
use crate::{alias_mut_slice_to_vec, alloc_aligned};
|
||||
use crate::{Infos, Module};
|
||||
use itertools::izip;
|
||||
use std::cmp::min;
|
||||
@@ -21,7 +22,7 @@ impl VecZnx {
|
||||
pub fn new(n: usize, limbs: usize) -> Self {
|
||||
Self {
|
||||
n: n,
|
||||
data: vec![i64::default(); n * limbs],
|
||||
data: alloc_aligned::<i64>(n * limbs, 64),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -47,7 +48,7 @@ impl VecZnx {
|
||||
|
||||
VecZnx {
|
||||
n: n,
|
||||
data: Vec::from(cast_mut_u8_to_mut_i64_slice(&mut buf[..size])),
|
||||
data: alias_mut_slice_to_vec(cast_mut_u8_to_mut_i64_slice(&mut buf[..size])),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -106,7 +107,7 @@ impl VecZnx {
|
||||
///
|
||||
/// # Example
|
||||
/// ```
|
||||
/// use base2k::{VecZnx, Encoding, Infos};
|
||||
/// use base2k::{VecZnx, Encoding, Infos, alloc_aligned};
|
||||
/// use itertools::izip;
|
||||
/// use sampling::source::Source;
|
||||
///
|
||||
@@ -115,8 +116,8 @@ impl VecZnx {
|
||||
/// let limbs: usize = 5; // number of limbs (i.e. can store coeffs in the range +/- 2^{limbs * log_base2k - 1})
|
||||
/// let log_k: usize = limbs * log_base2k - 5;
|
||||
/// let mut a: VecZnx = VecZnx::new(n, limbs);
|
||||
/// let mut carry: Vec<u8> = vec![u8::default(); a.n()<<3];
|
||||
/// let mut have: Vec<i64> = vec![i64::default(); a.n()];
|
||||
/// let mut carry: Vec<u8> = alloc_aligned::<u8>(a.n()<<3, 64);
|
||||
/// let mut have: Vec<i64> = alloc_aligned::<i64>(a.n(), 64);
|
||||
/// let mut source = Source::new([1; 32]);
|
||||
///
|
||||
/// // Populates the first limb of the of polynomials with random i64 values.
|
||||
@@ -135,7 +136,7 @@ impl VecZnx {
|
||||
/// .for_each(|x| assert!(x.abs() <= base_half, "|x|={} > 2^(k-1)={}", x, base_half));
|
||||
///
|
||||
/// // Ensures reconstructed normalized values are equal to non-normalized values.
|
||||
/// let mut want = vec![i64::default(); n];
|
||||
/// let mut want = alloc_aligned::<i64>(a.n(), 64);
|
||||
/// a.decode_vec_i64(log_base2k, log_k, &mut want);
|
||||
/// izip!(want, have).for_each(|(a, b)| assert_eq!(a, b, "{} != {}", a, b));
|
||||
/// ```
|
||||
|
||||
@@ -29,6 +29,25 @@ impl Module {
|
||||
unsafe { VecZnxBig(vec_znx_big::new_vec_znx_big(self.0, limbs as u64), limbs) }
|
||||
}
|
||||
|
||||
/// Returns a new [VecZnxBig] with the provided bytes array as backing array.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `limbs`: the number of limbs of the [VecZnxBig].
|
||||
/// * `bytes`: a byte array of size at least [Module::bytes_of_vec_znx_big].
|
||||
///
|
||||
/// # Panics
|
||||
/// If `bytes.len()` < [Module::bytes_of_vec_znx_big].
|
||||
pub fn new_vec_znx_big_from_bytes(&self, limbs: usize, bytes: &mut [u8]) -> VecZnxBig {
|
||||
assert!(
|
||||
bytes.len() >= self.bytes_of_vec_znx_big(limbs),
|
||||
"invalid bytes: bytes.len()={} < bytes_of_vec_znx_dft={}",
|
||||
bytes.len(),
|
||||
self.bytes_of_vec_znx_big(limbs)
|
||||
);
|
||||
VecZnxBig::from_bytes(limbs, bytes)
|
||||
}
|
||||
|
||||
/// Returns the minimum number of bytes necessary to allocate
|
||||
/// a new [VecZnxBig] through [VecZnxBig::from_bytes].
|
||||
pub fn bytes_of_vec_znx_big(&self, limbs: usize) -> usize {
|
||||
@@ -131,6 +150,42 @@ impl Module {
|
||||
}
|
||||
}
|
||||
|
||||
pub fn vec_znx_big_range_normalize_base2k_tmp_bytes(&self) -> usize {
|
||||
unsafe { vec_znx_big::vec_znx_big_range_normalize_base2k_tmp_bytes(self.0) as usize }
|
||||
}
|
||||
|
||||
pub fn vec_znx_big_range_normalize_base2k(
|
||||
&self,
|
||||
log_base2k: usize,
|
||||
res: &mut VecZnx,
|
||||
a: &VecZnxBig,
|
||||
a_range_begin: usize,
|
||||
a_range_xend: usize,
|
||||
a_range_step: usize,
|
||||
tmp_bytes: &mut [u8],
|
||||
) {
|
||||
assert!(
|
||||
tmp_bytes.len() >= self.vec_znx_big_range_normalize_base2k_tmp_bytes(),
|
||||
"invalid tmp_bytes: tmp_bytes.len()={} <= self.vec_znx_big_range_normalize_base2k_tmp_bytes()={}",
|
||||
tmp_bytes.len(),
|
||||
self.vec_znx_big_range_normalize_base2k_tmp_bytes()
|
||||
);
|
||||
unsafe {
|
||||
vec_znx_big::vec_znx_big_range_normalize_base2k(
|
||||
self.0,
|
||||
log_base2k as u64,
|
||||
res.as_mut_ptr(),
|
||||
res.limbs() as u64,
|
||||
res.n() as u64,
|
||||
a.0,
|
||||
a_range_begin as u64,
|
||||
a_range_xend as u64,
|
||||
a_range_step as u64,
|
||||
tmp_bytes.as_mut_ptr(),
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
pub fn vec_znx_big_automorphism(&self, gal_el: i64, b: &mut VecZnxBig, a: &VecZnxBig) {
|
||||
unsafe {
|
||||
vec_znx_big::vec_znx_big_automorphism(
|
||||
|
||||
@@ -39,7 +39,7 @@ impl Module {
|
||||
///
|
||||
/// # Panics
|
||||
/// If `bytes.len()` < [Module::bytes_of_vec_znx_dft].
|
||||
pub fn new_vec_znx_from_bytes(&self, limbs: usize, bytes: &mut [u8]) -> VecZnxDft {
|
||||
pub fn new_vec_znx_dft_from_bytes(&self, limbs: usize, bytes: &mut [u8]) -> VecZnxDft {
|
||||
assert!(
|
||||
bytes.len() >= self.bytes_of_vec_znx_dft(limbs),
|
||||
"invalid bytes: bytes.len()={} < bytes_of_vec_znx_dft={}",
|
||||
@@ -63,7 +63,9 @@ impl Module {
|
||||
b.limbs(),
|
||||
a_limbs
|
||||
);
|
||||
unsafe { vec_znx_dft::vec_znx_idft_tmp_a(self.0, b.0, b.limbs() as u64, a.0, a_limbs as u64) }
|
||||
unsafe {
|
||||
vec_znx_dft::vec_znx_idft_tmp_a(self.0, b.0, b.limbs() as u64, a.0, a_limbs as u64)
|
||||
}
|
||||
}
|
||||
|
||||
// Returns the size of the scratch space for [vec_znx_idft].
|
||||
|
||||
@@ -159,15 +159,17 @@ pub trait VmpPMatOps {
|
||||
/// vecznx.push(module.new_vec_znx(cols));
|
||||
/// });
|
||||
///
|
||||
/// let dble: Vec<&[i64]> = vecznx.iter().map(|v| v.data.as_slice()).collect();
|
||||
///
|
||||
/// let mut buf: Vec<u8> = vec![u8::default(); module.vmp_prepare_tmp_bytes(rows, cols)];
|
||||
///
|
||||
/// let mut vmp_pmat: VmpPMat = module.new_vmp_pmat(rows, cols);
|
||||
/// module.vmp_prepare_dblptr(&mut vmp_pmat, &vecznx, &mut buf);
|
||||
/// module.vmp_prepare_dblptr(&mut vmp_pmat, &dble, &mut buf);
|
||||
///
|
||||
/// vmp_pmat.free();
|
||||
/// module.free();
|
||||
/// ```
|
||||
fn vmp_prepare_dblptr(&self, b: &mut VmpPMat, a: &Vec<VecZnx>, buf: &mut [u8]);
|
||||
fn vmp_prepare_dblptr(&self, b: &mut VmpPMat, a: &[&[i64]], buf: &mut [u8]);
|
||||
|
||||
/// Prepares the ith-row of [VmpPMat] from a vector of [VecZnx].
|
||||
///
|
||||
@@ -189,7 +191,7 @@ pub trait VmpPMatOps {
|
||||
/// let rows: usize = 5;
|
||||
/// let cols: usize = 6;
|
||||
///
|
||||
/// let vecznx = module.new_vec_znx(cols);
|
||||
/// let vecznx = vec![0i64; cols*n];
|
||||
///
|
||||
/// let mut buf: Vec<u8> = vec![u8::default(); module.vmp_prepare_tmp_bytes(rows, cols)];
|
||||
///
|
||||
@@ -199,7 +201,7 @@ pub trait VmpPMatOps {
|
||||
/// vmp_pmat.free();
|
||||
/// module.free();
|
||||
/// ```
|
||||
fn vmp_prepare_row(&self, b: &mut VmpPMat, a: &VecZnx, row_i: usize, tmp_bytes: &mut [u8]);
|
||||
fn vmp_prepare_row(&self, b: &mut VmpPMat, a: &[i64], row_i: usize, tmp_bytes: &mut [u8]);
|
||||
|
||||
/// Returns the size of the stratch space necessary for [VmpPMatOps::vmp_apply_dft].
|
||||
///
|
||||
@@ -422,8 +424,8 @@ impl VmpPMatOps for Module {
|
||||
}
|
||||
}
|
||||
|
||||
fn vmp_prepare_dblptr(&self, b: &mut VmpPMat, a: &Vec<VecZnx>, buf: &mut [u8]) {
|
||||
let ptrs: Vec<*const i64> = a.iter().map(|v| v.data.as_ptr()).collect();
|
||||
fn vmp_prepare_dblptr(&self, b: &mut VmpPMat, a: &[&[i64]], buf: &mut [u8]) {
|
||||
let ptrs: Vec<*const i64> = a.iter().map(|v| v.as_ptr()).collect();
|
||||
unsafe {
|
||||
vmp::vmp_prepare_dblptr(
|
||||
self.0,
|
||||
@@ -436,12 +438,12 @@ impl VmpPMatOps for Module {
|
||||
}
|
||||
}
|
||||
|
||||
fn vmp_prepare_row(&self, b: &mut VmpPMat, a: &VecZnx, row_i: usize, buf: &mut [u8]) {
|
||||
fn vmp_prepare_row(&self, b: &mut VmpPMat, a: &[i64], row_i: usize, buf: &mut [u8]) {
|
||||
unsafe {
|
||||
vmp::vmp_prepare_row(
|
||||
self.0,
|
||||
b.data(),
|
||||
a.data.as_ptr(),
|
||||
a.as_ptr(),
|
||||
row_i as u64,
|
||||
b.rows() as u64,
|
||||
b.cols() as u64,
|
||||
|
||||
Reference in New Issue
Block a user