mirror of
https://github.com/arnaucube/poulpy.git
synced 2026-02-10 13:16:44 +01:00
wip
This commit is contained in:
@@ -54,11 +54,9 @@ pub trait GetZnxBase {
|
||||
fn znx_mut(&mut self) -> &mut ZnxBase;
|
||||
}
|
||||
|
||||
pub trait ZnxInfos: GetZnxBase {
|
||||
pub trait ZnxInfos {
|
||||
/// Returns the ring degree of the polynomials.
|
||||
fn n(&self) -> usize {
|
||||
self.znx().n
|
||||
}
|
||||
fn n(&self) -> usize;
|
||||
|
||||
/// Returns the base two logarithm of the ring dimension of the polynomials.
|
||||
fn log_n(&self) -> usize {
|
||||
@@ -66,41 +64,27 @@ pub trait ZnxInfos: GetZnxBase {
|
||||
}
|
||||
|
||||
/// Returns the number of rows.
|
||||
fn rows(&self) -> usize {
|
||||
self.znx().rows
|
||||
}
|
||||
fn rows(&self) -> usize;
|
||||
|
||||
/// Returns the number of polynomials in each row.
|
||||
fn cols(&self) -> usize {
|
||||
self.znx().cols
|
||||
}
|
||||
fn cols(&self) -> usize;
|
||||
|
||||
/// Returns the number of size per polynomial.
|
||||
fn size(&self) -> usize {
|
||||
self.znx().size
|
||||
}
|
||||
|
||||
/// Returns the underlying raw bytes array.
|
||||
fn data(&self) -> &[u8] {
|
||||
&self.znx().data
|
||||
}
|
||||
|
||||
/// Returns a pointer to the underlying raw bytes array.
|
||||
fn ptr(&self) -> *mut u8 {
|
||||
self.znx().ptr
|
||||
}
|
||||
fn size(&self) -> usize;
|
||||
|
||||
/// Returns the total number of small polynomials.
|
||||
fn poly_count(&self) -> usize {
|
||||
self.rows() * self.cols() * self.size()
|
||||
}
|
||||
}
|
||||
|
||||
pub trait ZnxSliceSize {
|
||||
/// Returns the slice size, which is the offset between
|
||||
/// two size of the same column.
|
||||
fn sl(&self) -> usize;
|
||||
}
|
||||
|
||||
// pub trait ZnxSliceSize {}
|
||||
|
||||
//(Jay) TODO: Remove ZnxAlloc
|
||||
pub trait ZnxAlloc<B: Backend>
|
||||
where
|
||||
Self: Sized + ZnxInfos,
|
||||
@@ -122,22 +106,21 @@ where
|
||||
fn bytes_of(module: &Module<B>, rows: usize, cols: usize, size: usize) -> usize;
|
||||
}
|
||||
|
||||
pub trait ZnxLayout: ZnxInfos {
|
||||
type Scalar;
|
||||
pub trait DataView {
|
||||
type D;
|
||||
fn data(&self) -> &Self::D;
|
||||
}
|
||||
|
||||
/// Returns true if the receiver is only borrowing the data.
|
||||
fn borrowing(&self) -> bool {
|
||||
self.znx().data.len() == 0
|
||||
}
|
||||
pub trait DataViewMut: DataView {
|
||||
fn data_mut(&self) -> &mut Self::D;
|
||||
}
|
||||
|
||||
pub trait ZnxView: ZnxInfos + DataView<D: AsRef<[u8]>> {
|
||||
type Scalar;
|
||||
|
||||
/// Returns a non-mutable pointer to the underlying coefficients array.
|
||||
fn as_ptr(&self) -> *const Self::Scalar {
|
||||
self.znx().ptr as *const Self::Scalar
|
||||
}
|
||||
|
||||
/// Returns a mutable pointer to the underlying coefficients array.
|
||||
fn as_mut_ptr(&mut self) -> *mut Self::Scalar {
|
||||
self.znx_mut().ptr as *mut Self::Scalar
|
||||
self.data().as_ref().as_ptr() as *const Self::Scalar
|
||||
}
|
||||
|
||||
/// Returns a non-mutable reference to the entire underlying coefficient array.
|
||||
@@ -145,11 +128,6 @@ pub trait ZnxLayout: ZnxInfos {
|
||||
unsafe { std::slice::from_raw_parts(self.as_ptr(), self.n() * self.poly_count()) }
|
||||
}
|
||||
|
||||
/// Returns a mutable reference to the entire underlying coefficient array.
|
||||
fn raw_mut(&mut self) -> &mut [Self::Scalar] {
|
||||
unsafe { std::slice::from_raw_parts_mut(self.as_mut_ptr(), self.n() * self.poly_count()) }
|
||||
}
|
||||
|
||||
/// Returns a non-mutable pointer starting at the j-th small polynomial of the i-th column.
|
||||
fn at_ptr(&self, i: usize, j: usize) -> *const Self::Scalar {
|
||||
#[cfg(debug_assertions)]
|
||||
@@ -161,6 +139,23 @@ pub trait ZnxLayout: ZnxInfos {
|
||||
unsafe { self.as_ptr().add(offset) }
|
||||
}
|
||||
|
||||
/// Returns non-mutable reference to the (i, j)-th small polynomial.
|
||||
fn at(&self, i: usize, j: usize) -> &[Self::Scalar] {
|
||||
unsafe { std::slice::from_raw_parts(self.at_ptr(i, j), self.n()) }
|
||||
}
|
||||
}
|
||||
|
||||
pub trait ZnxViewMut: ZnxView + DataViewMut<D: AsMut<[u8]>> {
|
||||
/// Returns a mutable pointer to the underlying coefficients array.
|
||||
fn as_mut_ptr(&mut self) -> *mut Self::Scalar {
|
||||
self.data_mut().as_mut().as_mut_ptr() as *mut Self::Scalar
|
||||
}
|
||||
|
||||
/// Returns a mutable reference to the entire underlying coefficient array.
|
||||
fn raw_mut(&mut self) -> &mut [Self::Scalar] {
|
||||
unsafe { std::slice::from_raw_parts_mut(self.as_mut_ptr(), self.n() * self.poly_count()) }
|
||||
}
|
||||
|
||||
/// Returns a mutable pointer starting at the j-th small polynomial of the i-th column.
|
||||
fn at_mut_ptr(&mut self, i: usize, j: usize) -> *mut Self::Scalar {
|
||||
#[cfg(debug_assertions)]
|
||||
@@ -172,17 +167,15 @@ pub trait ZnxLayout: ZnxInfos {
|
||||
unsafe { self.as_mut_ptr().add(offset) }
|
||||
}
|
||||
|
||||
/// Returns non-mutable reference to the (i, j)-th small polynomial.
|
||||
fn at(&self, i: usize, j: usize) -> &[Self::Scalar] {
|
||||
unsafe { std::slice::from_raw_parts(self.at_ptr(i, j), self.n()) }
|
||||
}
|
||||
|
||||
/// Returns mutable reference to the (i, j)-th small polynomial.
|
||||
fn at_mut(&mut self, i: usize, j: usize) -> &mut [Self::Scalar] {
|
||||
unsafe { std::slice::from_raw_parts_mut(self.at_mut_ptr(i, j), self.n()) }
|
||||
}
|
||||
}
|
||||
|
||||
//(Jay)Note: Can't provide blanket impl. of ZnxView because Scalar is not known
|
||||
impl<T> ZnxViewMut for T where T: ZnxView + DataViewMut<D: AsMut<[u8]>> {}
|
||||
|
||||
use std::convert::TryFrom;
|
||||
use std::num::TryFromIntError;
|
||||
use std::ops::{Add, AddAssign, Div, Mul, Neg, Shl, Shr, Sub};
|
||||
@@ -213,7 +206,7 @@ impl IntegerType for i128 {
|
||||
const BITS: u32 = 128;
|
||||
}
|
||||
|
||||
pub trait ZnxZero: ZnxLayout
|
||||
pub trait ZnxZero: ZnxViewMut
|
||||
where
|
||||
Self: Sized,
|
||||
{
|
||||
@@ -238,16 +231,16 @@ where
|
||||
}
|
||||
}
|
||||
|
||||
pub trait ZnxRsh: ZnxLayout + ZnxZero
|
||||
where
|
||||
Self: Sized,
|
||||
Self::Scalar: IntegerType,
|
||||
{
|
||||
pub trait ZnxRsh: ZnxZero {
|
||||
fn rsh(&mut self, k: usize, log_base2k: usize, col: usize, carry: &mut [u8]) {
|
||||
rsh(k, log_base2k, self, col, carry)
|
||||
}
|
||||
}
|
||||
|
||||
// Blanket implementations
|
||||
impl<T> ZnxZero for T where T: ZnxViewMut {}
|
||||
impl<T> ZnxRsh for T where T: ZnxZero {}
|
||||
|
||||
pub fn rsh<V: ZnxRsh + ZnxZero>(k: usize, log_base2k: usize, a: &mut V, a_col: usize, tmp_bytes: &mut [u8])
|
||||
where
|
||||
V::Scalar: IntegerType,
|
||||
@@ -310,10 +303,7 @@ pub fn rsh_tmp_bytes<T: IntegerType>(n: usize) -> usize {
|
||||
n * std::mem::size_of::<T>()
|
||||
}
|
||||
|
||||
pub fn switch_degree<T: ZnxLayout + ZnxZero>(b: &mut T, col_b: usize, a: &T, col_a: usize)
|
||||
where
|
||||
<T as ZnxLayout>::Scalar: IntegerType,
|
||||
{
|
||||
pub fn switch_degree<DMut: ZnxViewMut + ZnxZero, D: ZnxView>(b: &mut DMut, col_b: usize, a: &D, col_a: usize) {
|
||||
let (n_in, n_out) = (a.n(), b.n());
|
||||
let (gap_in, gap_out): (usize, usize);
|
||||
|
||||
@@ -334,3 +324,64 @@ where
|
||||
.for_each(|(x_in, x_out)| *x_out = *x_in);
|
||||
});
|
||||
}
|
||||
|
||||
// pub trait ZnxLayout: ZnxInfos {
|
||||
// type Scalar;
|
||||
|
||||
// /// Returns true if the receiver is only borrowing the data.
|
||||
// fn borrowing(&self) -> bool {
|
||||
// self.znx().data.len() == 0
|
||||
// }
|
||||
|
||||
// /// Returns a non-mutable pointer to the underlying coefficients array.
|
||||
// fn as_ptr(&self) -> *const Self::Scalar {
|
||||
// self.znx().ptr as *const Self::Scalar
|
||||
// }
|
||||
|
||||
// /// Returns a mutable pointer to the underlying coefficients array.
|
||||
// fn as_mut_ptr(&mut self) -> *mut Self::Scalar {
|
||||
// self.znx_mut().ptr as *mut Self::Scalar
|
||||
// }
|
||||
|
||||
// /// Returns a non-mutable reference to the entire underlying coefficient array.
|
||||
// fn raw(&self) -> &[Self::Scalar] {
|
||||
// unsafe { std::slice::from_raw_parts(self.as_ptr(), self.n() * self.poly_count()) }
|
||||
// }
|
||||
|
||||
// /// Returns a mutable reference to the entire underlying coefficient array.
|
||||
// fn raw_mut(&mut self) -> &mut [Self::Scalar] {
|
||||
// unsafe { std::slice::from_raw_parts_mut(self.as_mut_ptr(), self.n() * self.poly_count()) }
|
||||
// }
|
||||
|
||||
// /// Returns a non-mutable pointer starting at the j-th small polynomial of the i-th column.
|
||||
// fn at_ptr(&self, i: usize, j: usize) -> *const Self::Scalar {
|
||||
// #[cfg(debug_assertions)]
|
||||
// {
|
||||
// assert!(i < self.cols());
|
||||
// assert!(j < self.size());
|
||||
// }
|
||||
// let offset: usize = self.n() * (j * self.cols() + i);
|
||||
// unsafe { self.as_ptr().add(offset) }
|
||||
// }
|
||||
|
||||
// /// Returns a mutable pointer starting at the j-th small polynomial of the i-th column.
|
||||
// fn at_mut_ptr(&mut self, i: usize, j: usize) -> *mut Self::Scalar {
|
||||
// #[cfg(debug_assertions)]
|
||||
// {
|
||||
// assert!(i < self.cols());
|
||||
// assert!(j < self.size());
|
||||
// }
|
||||
// let offset: usize = self.n() * (j * self.cols() + i);
|
||||
// unsafe { self.as_mut_ptr().add(offset) }
|
||||
// }
|
||||
|
||||
// /// Returns non-mutable reference to the (i, j)-th small polynomial.
|
||||
// fn at(&self, i: usize, j: usize) -> &[Self::Scalar] {
|
||||
// unsafe { std::slice::from_raw_parts(self.at_ptr(i, j), self.n()) }
|
||||
// }
|
||||
|
||||
// /// Returns mutable reference to the (i, j)-th small polynomial.
|
||||
// fn at_mut(&mut self, i: usize, j: usize) -> &mut [Self::Scalar] {
|
||||
// unsafe { std::slice::from_raw_parts_mut(self.at_mut_ptr(i, j), self.n()) }
|
||||
// }
|
||||
// }
|
||||
|
||||
Reference in New Issue
Block a user