Applied discussed changes, everything working, but still to discuss

This commit is contained in:
Jean-Philippe Bossuat
2025-05-01 10:33:19 +02:00
parent 4e6fce3458
commit ca5e6d46c9
14 changed files with 710 additions and 508 deletions

View File

@@ -22,6 +22,33 @@ pub struct ZnxBase {
pub ptr: *mut u8,
}
impl ZnxBase {
pub fn from_bytes(n: usize, rows: usize, cols: usize, size: usize, mut bytes: Vec<u8>) -> Self {
let mut res: Self = Self::from_bytes_borrow(n, rows, cols, size, &mut bytes);
res.data = bytes;
res
}
pub fn from_bytes_borrow(n: usize, rows: usize, cols: usize, size: usize, bytes: &mut [u8]) -> Self {
#[cfg(debug_assertions)]
{
assert_eq!(n & (n - 1), 0, "n must be a power of two");
assert!(n > 0, "n must be greater than 0");
assert!(rows > 0, "rows must be greater than 0");
assert!(cols > 0, "cols must be greater than 0");
assert!(size > 0, "size must be greater than 0");
}
Self {
n: n,
rows: rows,
cols: cols,
size: size,
data: Vec::new(),
ptr: bytes.as_mut_ptr(),
}
}
}
pub trait GetZnxBase {
fn znx(&self) -> &ZnxBase;
fn znx_mut(&mut self) -> &mut ZnxBase;
@@ -52,10 +79,12 @@ pub trait ZnxInfos: GetZnxBase {
self.znx().size
}
/// Returns the underlying raw bytes array.
fn data(&self) -> &[u8] {
&self.znx().data
}
/// Returns a pointer to the underlying raw bytes array.
fn ptr(&self) -> *mut u8 {
self.znx().ptr
}
@@ -72,33 +101,6 @@ pub trait ZnxSliceSize {
fn sl(&self) -> usize;
}
impl ZnxBase {
pub fn from_bytes(n: usize, rows: usize, cols: usize, size: usize, mut bytes: Vec<u8>) -> Self {
let mut res: Self = Self::from_bytes_borrow(n, rows, cols, size, &mut bytes);
res.data = bytes;
res
}
pub fn from_bytes_borrow(n: usize, rows: usize, cols: usize, size: usize, bytes: &mut [u8]) -> Self {
#[cfg(debug_assertions)]
{
assert_eq!(n & (n - 1), 0, "n must be a power of two");
assert!(n > 0, "n must be greater than 0");
assert!(rows > 0, "rows must be greater than 0");
assert!(cols > 0, "cols must be greater than 0");
assert!(size > 0, "size must be greater than 0");
}
Self {
n: n,
rows: rows,
cols: cols,
size: size,
data: Vec::new(),
ptr: bytes.as_mut_ptr(),
}
}
}
pub trait ZnxAlloc<B: Backend>
where
Self: Sized + ZnxInfos,
@@ -148,25 +150,25 @@ pub trait ZnxLayout: ZnxInfos {
unsafe { std::slice::from_raw_parts_mut(self.as_mut_ptr(), self.n() * self.poly_count()) }
}
/// Returns a non-mutable pointer starting at the (i, j)-th small polynomial.
/// Returns a non-mutable pointer starting at the j-th small polynomial of the i-th column.
fn at_ptr(&self, i: usize, j: usize) -> *const Self::Scalar {
#[cfg(debug_assertions)]
{
assert!(i < self.cols());
assert!(j < self.size());
}
let offset = self.n() * (j * self.cols() + i);
let offset: usize = self.n() * (j * self.cols() + i);
unsafe { self.as_ptr().add(offset) }
}
/// Returns a mutable pointer starting at the (i, j)-th small polynomial.
/// Returns a mutable pointer starting at the j-th small polynomial of the i-th column.
fn at_mut_ptr(&mut self, i: usize, j: usize) -> *mut Self::Scalar {
#[cfg(debug_assertions)]
{
assert!(i < self.cols());
assert!(j < self.size());
}
let offset = self.n() * (j * self.cols() + i);
let offset: usize = self.n() * (j * self.cols() + i);
unsafe { self.as_mut_ptr().add(offset) }
}
@@ -179,16 +181,6 @@ pub trait ZnxLayout: ZnxInfos {
fn at_mut(&mut self, i: usize, j: usize) -> &mut [Self::Scalar] {
unsafe { std::slice::from_raw_parts_mut(self.at_mut_ptr(i, j), self.n()) }
}
/// Returns non-mutable reference to the i-th limb.
fn at_limb(&self, j: usize) -> &[Self::Scalar] {
unsafe { std::slice::from_raw_parts(self.at_ptr(0, j), self.n() * self.cols()) }
}
/// Returns mutable reference to the i-th limb.
fn at_limb_mut(&mut self, j: usize) -> &mut [Self::Scalar] {
unsafe { std::slice::from_raw_parts_mut(self.at_mut_ptr(0, j), self.n() * self.cols()) }
}
}
use std::convert::TryFrom;
@@ -221,14 +213,17 @@ impl IntegerType for i128 {
const BITS: u32 = 128;
}
pub trait ZnxBasics: ZnxLayout
pub trait ZnxZero: ZnxLayout
where
Self: Sized,
Self::Scalar: IntegerType,
{
fn zero(&mut self) {
unsafe {
std::ptr::write_bytes(self.as_mut_ptr(), 0, self.n() * size_of::<Self::Scalar>());
std::ptr::write_bytes(
self.as_mut_ptr(),
0,
self.n() * size_of::<Self::Scalar>() * self.poly_count(),
);
}
}
@@ -241,13 +236,19 @@ where
);
}
}
}
fn rsh(&mut self, log_base2k: usize, k: usize, carry: &mut [u8]) {
rsh(log_base2k, self, k, carry)
pub trait ZnxRsh: ZnxLayout + ZnxZero
where
Self: Sized,
Self::Scalar: IntegerType,
{
fn rsh(&mut self, k: usize, log_base2k: usize, col: usize, carry: &mut [u8]) {
rsh(k, log_base2k, self, col, carry)
}
}
pub fn rsh<V: ZnxBasics>(log_base2k: usize, a: &mut V, k: usize, tmp_bytes: &mut [u8])
pub fn rsh<V: ZnxRsh + ZnxZero>(k: usize, log_base2k: usize, a: &mut V, a_col: usize, tmp_bytes: &mut [u8])
where
V::Scalar: IntegerType,
{
@@ -258,7 +259,7 @@ where
#[cfg(debug_assertions)]
{
assert!(
tmp_bytes.len() >= rsh_tmp_bytes::<V::Scalar>(n, cols),
tmp_bytes.len() >= rsh_tmp_bytes::<V::Scalar>(n),
"invalid carry: carry.len()/size_ofSelf::Scalar={} < rsh_tmp_bytes({}, {})",
tmp_bytes.len() / size_of::<V::Scalar>(),
n,
@@ -291,7 +292,7 @@ where
let k_rem_t: V::Scalar = V::Scalar::try_from(k_rem).unwrap();
(steps..size).for_each(|i| {
izip!(carry.iter_mut(), a.at_limb_mut(i).iter_mut()).for_each(|(ci, xi)| {
izip!(carry.iter_mut(), a.at_mut(a_col, i).iter_mut()).for_each(|(ci, xi)| {
*xi += *ci << log_base2k_t;
*ci = get_base_k_carry(*xi, shift);
*xi = (*xi - *ci) >> k_rem_t;
@@ -305,11 +306,11 @@ fn get_base_k_carry<T: IntegerType>(x: T, shift: T) -> T {
(x << shift) >> shift
}
pub fn rsh_tmp_bytes<T: IntegerType>(n: usize, cols: usize) -> usize {
n * cols * std::mem::size_of::<T>()
pub fn rsh_tmp_bytes<T: IntegerType>(n: usize) -> usize {
n * std::mem::size_of::<T>()
}
pub fn switch_degree<T: ZnxLayout + ZnxBasics>(b: &mut T, col_b: usize, a: &T, col_a: usize)
pub fn switch_degree<T: ZnxLayout + ZnxZero>(b: &mut T, col_b: usize, a: &T, col_a: usize)
where
<T as ZnxLayout>::Scalar: IntegerType,
{