mirror of
https://github.com/arnaucube/poulpy.git
synced 2026-02-10 21:26:41 +01:00
everything compiles. Scratchpad not yet implemented
This commit is contained in:
@@ -85,26 +85,26 @@ pub trait ZnxInfos {
|
||||
// pub trait ZnxSliceSize {}
|
||||
|
||||
//(Jay) TODO: Remove ZnxAlloc
|
||||
pub trait ZnxAlloc<B: Backend>
|
||||
where
|
||||
Self: Sized + ZnxInfos,
|
||||
{
|
||||
type Scalar;
|
||||
fn new(module: &Module<B>, rows: usize, cols: usize, size: usize) -> Self {
|
||||
let bytes: Vec<u8> = alloc_aligned::<u8>(Self::bytes_of(module, rows, cols, size));
|
||||
Self::from_bytes(module, rows, cols, size, bytes)
|
||||
}
|
||||
// pub trait ZnxAlloc<B: Backend>
|
||||
// where
|
||||
// Self: Sized + ZnxInfos,
|
||||
// {
|
||||
// type Scalar;
|
||||
// fn new(module: &Module<B>, rows: usize, cols: usize, size: usize) -> Self {
|
||||
// let bytes: Vec<u8> = alloc_aligned::<u8>(Self::bytes_of(module, rows, cols, size));
|
||||
// Self::from_bytes(module, rows, cols, size, bytes)
|
||||
// }
|
||||
|
||||
fn from_bytes(module: &Module<B>, rows: usize, cols: usize, size: usize, mut bytes: Vec<u8>) -> Self {
|
||||
let mut res: Self = Self::from_bytes_borrow(module, rows, cols, size, &mut bytes);
|
||||
res.znx_mut().data = bytes;
|
||||
res
|
||||
}
|
||||
// fn from_bytes(module: &Module<B>, rows: usize, cols: usize, size: usize, mut bytes: Vec<u8>) -> Self {
|
||||
// let mut res: Self = Self::from_bytes_borrow(module, rows, cols, size, &mut bytes);
|
||||
// res.znx_mut().data = bytes;
|
||||
// res
|
||||
// }
|
||||
|
||||
fn from_bytes_borrow(module: &Module<B>, rows: usize, cols: usize, size: usize, bytes: &mut [u8]) -> Self;
|
||||
// fn from_bytes_borrow(module: &Module<B>, rows: usize, cols: usize, size: usize, bytes: &mut [u8]) -> Self;
|
||||
|
||||
fn bytes_of(module: &Module<B>, rows: usize, cols: usize, size: usize) -> usize;
|
||||
}
|
||||
// fn bytes_of(module: &Module<B>, rows: usize, cols: usize, size: usize) -> usize;
|
||||
// }
|
||||
|
||||
pub trait DataView {
|
||||
type D;
|
||||
@@ -112,11 +112,11 @@ pub trait DataView {
|
||||
}
|
||||
|
||||
pub trait DataViewMut: DataView {
|
||||
fn data_mut(&self) -> &mut Self::D;
|
||||
fn data_mut(&mut self) -> &mut Self::D;
|
||||
}
|
||||
|
||||
pub trait ZnxView: ZnxInfos + DataView<D: AsRef<[u8]>> {
|
||||
type Scalar;
|
||||
type Scalar: Copy;
|
||||
|
||||
/// Returns a non-mutable pointer to the underlying coefficients array.
|
||||
fn as_ptr(&self) -> *const Self::Scalar {
|
||||
@@ -177,11 +177,9 @@ pub trait ZnxViewMut: ZnxView + DataViewMut<D: AsMut<[u8]>> {
|
||||
impl<T> ZnxViewMut for T where T: ZnxView + DataViewMut<D: AsMut<[u8]>> {}
|
||||
|
||||
use std::convert::TryFrom;
|
||||
use std::num::TryFromIntError;
|
||||
use std::ops::{Add, AddAssign, Div, Mul, Neg, Shl, Shr, Sub};
|
||||
pub trait IntegerType:
|
||||
pub trait Num:
|
||||
Copy
|
||||
+ std::fmt::Debug
|
||||
+ Default
|
||||
+ PartialEq
|
||||
+ PartialOrd
|
||||
@@ -190,22 +188,23 @@ pub trait IntegerType:
|
||||
+ Mul<Output = Self>
|
||||
+ Div<Output = Self>
|
||||
+ Neg<Output = Self>
|
||||
+ Shr<Output = Self>
|
||||
+ Shl<Output = Self>
|
||||
+ AddAssign
|
||||
+ TryFrom<usize, Error = TryFromIntError>
|
||||
{
|
||||
const BITS: u32;
|
||||
}
|
||||
|
||||
impl IntegerType for i64 {
|
||||
impl Num for i64 {
|
||||
const BITS: u32 = 64;
|
||||
}
|
||||
|
||||
impl IntegerType for i128 {
|
||||
impl Num for i128 {
|
||||
const BITS: u32 = 128;
|
||||
}
|
||||
|
||||
impl Num for f64 {
|
||||
const BITS: u32 = 64;
|
||||
}
|
||||
|
||||
pub trait ZnxZero: ZnxViewMut
|
||||
where
|
||||
Self: Sized,
|
||||
@@ -231,79 +230,16 @@ where
|
||||
}
|
||||
}
|
||||
|
||||
pub trait ZnxRsh: ZnxZero {
|
||||
fn rsh(&mut self, k: usize, log_base2k: usize, col: usize, carry: &mut [u8]) {
|
||||
rsh(k, log_base2k, self, col, carry)
|
||||
}
|
||||
}
|
||||
|
||||
// Blanket implementations
|
||||
impl<T> ZnxZero for T where T: ZnxViewMut {}
|
||||
impl<T> ZnxRsh for T where T: ZnxZero {}
|
||||
// impl<T> ZnxRsh for T where T: ZnxZero {}
|
||||
|
||||
pub fn rsh<V: ZnxRsh + ZnxZero>(k: usize, log_base2k: usize, a: &mut V, a_col: usize, tmp_bytes: &mut [u8])
|
||||
where
|
||||
V::Scalar: IntegerType,
|
||||
{
|
||||
let n: usize = a.n();
|
||||
let size: usize = a.size();
|
||||
let cols: usize = a.cols();
|
||||
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
assert!(
|
||||
tmp_bytes.len() >= rsh_tmp_bytes::<V::Scalar>(n),
|
||||
"invalid carry: carry.len()/size_ofSelf::Scalar={} < rsh_tmp_bytes({}, {})",
|
||||
tmp_bytes.len() / size_of::<V::Scalar>(),
|
||||
n,
|
||||
size,
|
||||
);
|
||||
assert_alignement(tmp_bytes.as_ptr());
|
||||
}
|
||||
|
||||
let size: usize = a.size();
|
||||
let steps: usize = k / log_base2k;
|
||||
|
||||
a.raw_mut().rotate_right(n * steps * cols);
|
||||
(0..cols).for_each(|i| {
|
||||
(0..steps).for_each(|j| {
|
||||
a.zero_at(i, j);
|
||||
})
|
||||
});
|
||||
|
||||
let k_rem: usize = k % log_base2k;
|
||||
|
||||
if k_rem != 0 {
|
||||
let carry: &mut [V::Scalar] = cast_mut(tmp_bytes);
|
||||
|
||||
unsafe {
|
||||
std::ptr::write_bytes(carry.as_mut_ptr(), 0, n * size_of::<V::Scalar>());
|
||||
}
|
||||
|
||||
let log_base2k_t: V::Scalar = V::Scalar::try_from(log_base2k).unwrap();
|
||||
let shift: V::Scalar = V::Scalar::try_from(V::Scalar::BITS as usize - k_rem).unwrap();
|
||||
let k_rem_t: V::Scalar = V::Scalar::try_from(k_rem).unwrap();
|
||||
|
||||
(steps..size).for_each(|i| {
|
||||
izip!(carry.iter_mut(), a.at_mut(a_col, i).iter_mut()).for_each(|(ci, xi)| {
|
||||
*xi += *ci << log_base2k_t;
|
||||
*ci = get_base_k_carry(*xi, shift);
|
||||
*xi = (*xi - *ci) >> k_rem_t;
|
||||
});
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
#[inline(always)]
|
||||
fn get_base_k_carry<T: IntegerType>(x: T, shift: T) -> T {
|
||||
(x << shift) >> shift
|
||||
}
|
||||
|
||||
pub fn rsh_tmp_bytes<T: IntegerType>(n: usize) -> usize {
|
||||
n * std::mem::size_of::<T>()
|
||||
}
|
||||
|
||||
pub fn switch_degree<DMut: ZnxViewMut + ZnxZero, D: ZnxView>(b: &mut DMut, col_b: usize, a: &D, col_a: usize) {
|
||||
pub fn switch_degree<S: Copy, DMut: ZnxViewMut<Scalar = S> + ZnxZero, D: ZnxView<Scalar = S>>(
|
||||
b: &mut DMut,
|
||||
col_b: usize,
|
||||
a: &D,
|
||||
col_a: usize,
|
||||
) {
|
||||
let (n_in, n_out) = (a.n(), b.n());
|
||||
let (gap_in, gap_out): (usize, usize);
|
||||
|
||||
@@ -325,6 +261,71 @@ pub fn switch_degree<DMut: ZnxViewMut + ZnxZero, D: ZnxView>(b: &mut DMut, col_b
|
||||
});
|
||||
}
|
||||
|
||||
// (Jay)TODO: implement rsh for VecZnx, VecZnxBig
|
||||
// pub trait ZnxRsh: ZnxZero {
|
||||
// fn rsh(&mut self, k: usize, log_base2k: usize, col: usize, carry: &mut [u8]) {
|
||||
// rsh(k, log_base2k, self, col, carry)
|
||||
// }
|
||||
// }
|
||||
// pub fn rsh<V: ZnxRsh + ZnxZero>(k: usize, log_base2k: usize, a: &mut V, a_col: usize, tmp_bytes: &mut [u8]) {
|
||||
// let n: usize = a.n();
|
||||
// let size: usize = a.size();
|
||||
// let cols: usize = a.cols();
|
||||
|
||||
// #[cfg(debug_assertions)]
|
||||
// {
|
||||
// assert!(
|
||||
// tmp_bytes.len() >= rsh_tmp_bytes::<V::Scalar>(n),
|
||||
// "invalid carry: carry.len()/size_ofSelf::Scalar={} < rsh_tmp_bytes({}, {})",
|
||||
// tmp_bytes.len() / size_of::<V::Scalar>(),
|
||||
// n,
|
||||
// size,
|
||||
// );
|
||||
// assert_alignement(tmp_bytes.as_ptr());
|
||||
// }
|
||||
|
||||
// let size: usize = a.size();
|
||||
// let steps: usize = k / log_base2k;
|
||||
|
||||
// a.raw_mut().rotate_right(n * steps * cols);
|
||||
// (0..cols).for_each(|i| {
|
||||
// (0..steps).for_each(|j| {
|
||||
// a.zero_at(i, j);
|
||||
// })
|
||||
// });
|
||||
|
||||
// let k_rem: usize = k % log_base2k;
|
||||
|
||||
// if k_rem != 0 {
|
||||
// let carry: &mut [V::Scalar] = cast_mut(tmp_bytes);
|
||||
|
||||
// unsafe {
|
||||
// std::ptr::write_bytes(carry.as_mut_ptr(), 0, n * size_of::<V::Scalar>());
|
||||
// }
|
||||
|
||||
// let log_base2k_t: V::Scalar = V::Scalar::try_from(log_base2k).unwrap();
|
||||
// let shift: V::Scalar = V::Scalar::try_from(V::Scalar::BITS as usize - k_rem).unwrap();
|
||||
// let k_rem_t: V::Scalar = V::Scalar::try_from(k_rem).unwrap();
|
||||
|
||||
// (steps..size).for_each(|i| {
|
||||
// izip!(carry.iter_mut(), a.at_mut(a_col, i).iter_mut()).for_each(|(ci, xi)| {
|
||||
// *xi += *ci << log_base2k_t;
|
||||
// *ci = get_base_k_carry(*xi, shift);
|
||||
// *xi = (*xi - *ci) >> k_rem_t;
|
||||
// });
|
||||
// })
|
||||
// }
|
||||
// }
|
||||
|
||||
// #[inline(always)]
|
||||
// fn get_base_k_carry<T: Num>(x: T, shift: T) -> T {
|
||||
// (x << shift) >> shift
|
||||
// }
|
||||
|
||||
// pub fn rsh_tmp_bytes<T: Num>(n: usize) -> usize {
|
||||
// n * std::mem::size_of::<T>()
|
||||
// }
|
||||
|
||||
// pub trait ZnxLayout: ZnxInfos {
|
||||
// type Scalar;
|
||||
|
||||
|
||||
Reference in New Issue
Block a user