mirror of
https://github.com/arnaucube/poulpy.git
synced 2026-02-10 13:16:44 +01:00
389 lines
12 KiB
Rust
389 lines
12 KiB
Rust
use crate::{Backend, Module, alloc_aligned, assert_alignement, cast_mut};
|
|
use itertools::izip;
|
|
use std::cmp::min;
|
|
|
|
/// Shape metadata and storage handle for a polynomial container.
///
/// The coefficient bytes are either owned (moved into `data` by
/// [`ZnxBase::from_bytes`]) or borrowed (`data` left empty by
/// [`ZnxBase::from_bytes_borrow`]); in both cases `ptr` points at their start.
pub struct ZnxBase {
    /// The ring degree.
    pub n: usize,

    /// The number of rows (in the third dimension).
    pub rows: usize,

    /// The number of polynomials (columns).
    pub cols: usize,

    /// The number of small polynomials (limbs) per polynomial.
    pub size: usize,

    /// Polynomial coefficients, as a contiguous byte array. Each column is equally spaced by `n`.
    pub data: Vec<u8>,

    /// Pointer to the coefficient bytes (`data` is empty when the value only borrows
    /// its storage, e.g. when built with [`ZnxBase::from_bytes_borrow`]).
    pub ptr: *mut u8,
}

impl ZnxBase {
    /// Builds a `ZnxBase` that takes ownership of `bytes`.
    ///
    /// `ptr` is taken from `bytes` before the `Vec` is moved into `data`; moving a
    /// `Vec` does not relocate its heap buffer, so `ptr` stays valid for as long as
    /// `data` is neither reallocated nor dropped.
    ///
    /// # Panics
    /// In debug builds, under the same conditions as [`ZnxBase::from_bytes_borrow`].
    pub fn from_bytes(n: usize, rows: usize, cols: usize, size: usize, mut bytes: Vec<u8>) -> Self {
        let mut res: Self = Self::from_bytes_borrow(n, rows, cols, size, &mut bytes);
        res.data = bytes;
        res
    }

    /// Builds a `ZnxBase` that points into `bytes` without owning them (`data` stays empty).
    ///
    /// The returned value keeps only a raw pointer with no lifetime attached: the
    /// caller must keep `bytes` alive (and in place) for as long as the result is used.
    /// No check is made that `bytes` is large enough for the requested shape.
    ///
    /// # Panics
    /// In debug builds, panics if `n` is zero or not a power of two, or if any of
    /// `rows`, `cols`, `size` is zero.
    pub fn from_bytes_borrow(n: usize, rows: usize, cols: usize, size: usize, bytes: &mut [u8]) -> Self {
        #[cfg(debug_assertions)]
        {
            // `n > 0` must be checked before the power-of-two test: computing `n - 1`
            // for `n == 0` overflows and panics with an unrelated message.
            assert!(n > 0, "n must be greater than 0");
            assert!(n.is_power_of_two(), "n must be a power of two");
            assert!(rows > 0, "rows must be greater than 0");
            assert!(cols > 0, "cols must be greater than 0");
            assert!(size > 0, "size must be greater than 0");
        }
        Self {
            n,
            rows,
            cols,
            size,
            data: Vec::new(),
            ptr: bytes.as_mut_ptr(),
        }
    }
}
|
|
|
|
pub trait GetZnxBase {
|
|
fn znx(&self) -> &ZnxBase;
|
|
fn znx_mut(&mut self) -> &mut ZnxBase;
|
|
}
|
|
|
|
/// Shape information of a polynomial container: `rows * cols * size` small
/// polynomials, each with `n` coefficients.
pub trait ZnxInfos {
    /// Returns the ring degree `n` of the polynomials.
    fn n(&self) -> usize;

    /// Returns the base-two logarithm of the ring degree.
    ///
    /// Exact when `n()` is a power of two; for other values the result is rounded up.
    /// In debug builds, `n() == 0` triggers a subtraction-overflow panic.
    fn log_n(&self) -> usize {
        let highest_index: usize = self.n() - 1;
        let significant_bits: u32 = usize::BITS - highest_index.leading_zeros();
        significant_bits as usize
    }

    /// Returns the number of rows.
    fn rows(&self) -> usize;

    /// Returns the number of polynomials (columns) in each row.
    fn cols(&self) -> usize;

    /// Returns the number of small polynomials (limbs) per polynomial.
    fn size(&self) -> usize;

    /// Returns the total number of small polynomials, `rows() * cols() * size()`.
    fn poly_count(&self) -> usize {
        let total: usize = self.rows() * self.cols() * self.size();
        total
    }

    /// Returns the slice size: the offset between two consecutive limbs
    /// (small polynomials) of the same column.
    fn sl(&self) -> usize;
}
|
|
|
|
// pub trait ZnxSliceSize {}
|
|
|
|
//(Jay) TODO: Remove ZnxAlloc
|
|
// pub trait ZnxAlloc<B: Backend>
|
|
// where
|
|
// Self: Sized + ZnxInfos,
|
|
// {
|
|
// type Scalar;
|
|
// fn new(module: &Module<B>, rows: usize, cols: usize, size: usize) -> Self {
|
|
// let bytes: Vec<u8> = alloc_aligned::<u8>(Self::bytes_of(module, rows, cols, size));
|
|
// Self::from_bytes(module, rows, cols, size, bytes)
|
|
// }
|
|
|
|
// fn from_bytes(module: &Module<B>, rows: usize, cols: usize, size: usize, mut bytes: Vec<u8>) -> Self {
|
|
// let mut res: Self = Self::from_bytes_borrow(module, rows, cols, size, &mut bytes);
|
|
// res.znx_mut().data = bytes;
|
|
// res
|
|
// }
|
|
|
|
// fn from_bytes_borrow(module: &Module<B>, rows: usize, cols: usize, size: usize, bytes: &mut [u8]) -> Self;
|
|
|
|
// fn bytes_of(module: &Module<B>, rows: usize, cols: usize, size: usize) -> usize;
|
|
// }
|
|
|
|
/// Read access to the backing storage of a container.
pub trait DataView {
    /// Type of the backing storage (`ZnxView` additionally requires it to be
    /// viewable as `[u8]`).
    type D;

    /// Returns a shared reference to the backing storage.
    fn data(&self) -> &<Self as DataView>::D;
}

/// Write access to the backing storage of a container.
pub trait DataViewMut: DataView {
    /// Returns a mutable reference to the backing storage.
    fn data_mut(&mut self) -> &mut <Self as DataView>::D;
}
|
|
|
|
/// Read-only view of a container's coefficients.
///
/// The bytes returned by [`DataView::data`] are reinterpreted as a contiguous
/// array of `Self::Scalar`. The `j`-th small polynomial of column `i` starts at
/// scalar offset `n * (j * cols + i)`: limbs are the outer dimension and the
/// columns are interleaved inside each limb.
pub trait ZnxView: ZnxInfos + DataView<D: AsRef<[u8]>> {
    /// Coefficient type that the backing bytes are reinterpreted as.
    type Scalar: Copy;

    /// Returns a non-mutable pointer to the underlying coefficients array.
    fn as_ptr(&self) -> *const Self::Scalar {
        self.data().as_ref().as_ptr() as *const Self::Scalar
    }

    /// Returns a non-mutable reference to the entire underlying coefficient array
    /// (`n() * poly_count()` scalars).
    ///
    /// # Panics
    /// In debug builds, panics if the backing bytes are too short to hold that many
    /// scalars (building the slice would otherwise be undefined behaviour).
    fn raw(&self) -> &[Self::Scalar] {
        let len: usize = self.n() * self.poly_count();
        #[cfg(debug_assertions)]
        {
            let available: usize = self.data().as_ref().len();
            let needed: usize = len * std::mem::size_of::<Self::Scalar>();
            assert!(
                available >= needed,
                "backing buffer too small: {} bytes available, {} bytes needed",
                available,
                needed
            );
        }
        // SAFETY: `as_ptr` points at the start of the backing bytes. Implementors must
        // provide at least `len` scalars' worth of bytes, aligned for `Self::Scalar`;
        // the length requirement is verified above in debug builds.
        unsafe { std::slice::from_raw_parts(self.as_ptr(), len) }
    }

    /// Returns a non-mutable pointer starting at the j-th small polynomial of the i-th column.
    ///
    /// # Panics
    /// In debug builds, panics if `i >= cols()`, `j >= size()`, or if the backing
    /// bytes do not extend to the end of the requested small polynomial.
    fn at_ptr(&self, i: usize, j: usize) -> *const Self::Scalar {
        #[cfg(debug_assertions)]
        {
            assert!(i < self.cols());
            assert!(j < self.size());
        }
        let offset: usize = self.n() * (j * self.cols() + i);
        #[cfg(debug_assertions)]
        {
            let available: usize = self.data().as_ref().len();
            let needed: usize = (offset + self.n()) * std::mem::size_of::<Self::Scalar>();
            assert!(
                available >= needed,
                "backing buffer too small: {} bytes available, {} bytes needed",
                available,
                needed
            );
        }
        // SAFETY: `offset` stays inside the backing buffer (checked above in debug
        // builds; implementors must guarantee it in release builds).
        unsafe { self.as_ptr().add(offset) }
    }

    /// Returns a non-mutable reference to the `n()` coefficients of the (i, j)-th small polynomial.
    fn at(&self, i: usize, j: usize) -> &[Self::Scalar] {
        // SAFETY: `at_ptr` returns the start of an `n()`-scalar small polynomial that
        // lies within the backing buffer.
        unsafe { std::slice::from_raw_parts(self.at_ptr(i, j), self.n()) }
    }
}
|
|
|
|
/// Mutable counterpart of [`ZnxView`]: exposes the coefficients for writing.
///
/// Uses the same layout as [`ZnxView`]: the `j`-th small polynomial of column `i`
/// starts at scalar offset `n * (j * cols + i)`.
pub trait ZnxViewMut: ZnxView + DataViewMut<D: AsMut<[u8]>> {
    /// Returns a mutable pointer to the underlying coefficients array.
    fn as_mut_ptr(&mut self) -> *mut Self::Scalar {
        self.data_mut().as_mut().as_mut_ptr() as *mut Self::Scalar
    }

    /// Returns a mutable reference to the entire underlying coefficient array
    /// (`n() * poly_count()` scalars).
    ///
    /// # Panics
    /// In debug builds, panics if the backing bytes are too short to hold that many
    /// scalars (building the slice would otherwise be undefined behaviour).
    fn raw_mut(&mut self) -> &mut [Self::Scalar] {
        let len: usize = self.n() * self.poly_count();
        #[cfg(debug_assertions)]
        {
            let available: usize = self.data_mut().as_mut().len();
            let needed: usize = len * std::mem::size_of::<Self::Scalar>();
            assert!(
                available >= needed,
                "backing buffer too small: {} bytes available, {} bytes needed",
                available,
                needed
            );
        }
        // SAFETY: `as_mut_ptr` points at the start of the backing bytes. Implementors
        // must provide at least `len` scalars' worth of bytes, aligned for
        // `Self::Scalar`; the length requirement is verified above in debug builds.
        unsafe { std::slice::from_raw_parts_mut(self.as_mut_ptr(), len) }
    }

    /// Returns a mutable pointer starting at the j-th small polynomial of the i-th column.
    ///
    /// # Panics
    /// In debug builds, panics if `i >= cols()`, `j >= size()`, or if the backing
    /// bytes do not extend to the end of the requested small polynomial.
    fn at_mut_ptr(&mut self, i: usize, j: usize) -> *mut Self::Scalar {
        #[cfg(debug_assertions)]
        {
            assert!(i < self.cols());
            assert!(j < self.size());
        }
        let offset: usize = self.n() * (j * self.cols() + i);
        #[cfg(debug_assertions)]
        {
            let needed: usize = (offset + self.n()) * std::mem::size_of::<Self::Scalar>();
            let available: usize = self.data_mut().as_mut().len();
            assert!(
                available >= needed,
                "backing buffer too small: {} bytes available, {} bytes needed",
                available,
                needed
            );
        }
        // SAFETY: `offset` stays inside the backing buffer (checked above in debug
        // builds; implementors must guarantee it in release builds).
        unsafe { self.as_mut_ptr().add(offset) }
    }

    /// Returns a mutable reference to the `n()` coefficients of the (i, j)-th small polynomial.
    fn at_mut(&mut self, i: usize, j: usize) -> &mut [Self::Scalar] {
        let n: usize = self.n();
        // SAFETY: `at_mut_ptr` returns the start of an `n()`-scalar small polynomial
        // that lies within the backing buffer, and `&mut self` guarantees exclusivity.
        unsafe { std::slice::from_raw_parts_mut(self.at_mut_ptr(i, j), n) }
    }
}
|
|
|
|
//(Jay)Note: Can't provide blanket impl. of ZnxView because Scalar is not known
|
|
impl<T> ZnxViewMut for T where T: ZnxView + DataViewMut<D: AsMut<[u8]>> {}
|
|
|
|
use std::convert::TryFrom;
|
|
use std::ops::{Add, AddAssign, Div, Mul, Neg, Shl, Shr, Sub};
|
|
/// Minimal arithmetic interface shared by the scalar types listed below
/// (`i64`, `i128`, `f64`).
pub trait Num:
    Copy
    + Default
    + PartialEq
    + PartialOrd
    + Add<Output = Self>
    + Sub<Output = Self>
    + Mul<Output = Self>
    + Div<Output = Self>
    + Neg<Output = Self>
    + AddAssign
{
    /// Width, in bits, of the type's representation.
    const BITS: u32;
}

impl Num for i64 {
    const BITS: u32 = i64::BITS;
}

impl Num for i128 {
    const BITS: u32 = i128::BITS;
}

impl Num for f64 {
    // `f64` has no inherent `BITS`; its bit pattern is a `u64` (see `f64::to_bits`).
    const BITS: u32 = u64::BITS;
}
|
|
|
|
pub trait ZnxZero: ZnxViewMut
|
|
where
|
|
Self: Sized,
|
|
{
|
|
fn zero(&mut self) {
|
|
unsafe {
|
|
std::ptr::write_bytes(
|
|
self.as_mut_ptr(),
|
|
0,
|
|
self.n() * size_of::<Self::Scalar>() * self.poly_count(),
|
|
);
|
|
}
|
|
}
|
|
|
|
fn zero_at(&mut self, i: usize, j: usize) {
|
|
unsafe {
|
|
std::ptr::write_bytes(
|
|
self.at_mut_ptr(i, j),
|
|
0,
|
|
self.n() * size_of::<Self::Scalar>(),
|
|
);
|
|
}
|
|
}
|
|
}
|
|
|
|
// Blanket implementations
|
|
impl<T> ZnxZero for T where T: ZnxViewMut {}
|
|
// impl<T> ZnxRsh for T where T: ZnxZero {}
|
|
|
|
pub fn switch_degree<S: Copy, DMut: ZnxViewMut<Scalar = S> + ZnxZero, D: ZnxView<Scalar = S>>(
|
|
b: &mut DMut,
|
|
col_b: usize,
|
|
a: &D,
|
|
col_a: usize,
|
|
) {
|
|
let (n_in, n_out) = (a.n(), b.n());
|
|
let (gap_in, gap_out): (usize, usize);
|
|
|
|
if n_in > n_out {
|
|
(gap_in, gap_out) = (n_in / n_out, 1)
|
|
} else {
|
|
(gap_in, gap_out) = (1, n_out / n_in);
|
|
b.zero();
|
|
}
|
|
|
|
let size: usize = min(a.size(), b.size());
|
|
|
|
(0..size).for_each(|i| {
|
|
izip!(
|
|
a.at(col_a, i).iter().step_by(gap_in),
|
|
b.at_mut(col_b, i).iter_mut().step_by(gap_out)
|
|
)
|
|
.for_each(|(x_in, x_out)| *x_out = *x_in);
|
|
});
|
|
}
|
|
|
|
// (Jay)TODO: implement rsh for VecZnx, VecZnxBig
|
|
// pub trait ZnxRsh: ZnxZero {
|
|
// fn rsh(&mut self, k: usize, log_base2k: usize, col: usize, carry: &mut [u8]) {
|
|
// rsh(k, log_base2k, self, col, carry)
|
|
// }
|
|
// }
|
|
// pub fn rsh<V: ZnxRsh + ZnxZero>(k: usize, log_base2k: usize, a: &mut V, a_col: usize, tmp_bytes: &mut [u8]) {
|
|
// let n: usize = a.n();
|
|
// let size: usize = a.size();
|
|
// let cols: usize = a.cols();
|
|
|
|
// #[cfg(debug_assertions)]
|
|
// {
|
|
// assert!(
|
|
// tmp_bytes.len() >= rsh_tmp_bytes::<V::Scalar>(n),
|
|
// "invalid carry: carry.len()/size_ofSelf::Scalar={} < rsh_tmp_bytes({}, {})",
|
|
// tmp_bytes.len() / size_of::<V::Scalar>(),
|
|
// n,
|
|
// size,
|
|
// );
|
|
// assert_alignement(tmp_bytes.as_ptr());
|
|
// }
|
|
|
|
// let size: usize = a.size();
|
|
// let steps: usize = k / log_base2k;
|
|
|
|
// a.raw_mut().rotate_right(n * steps * cols);
|
|
// (0..cols).for_each(|i| {
|
|
// (0..steps).for_each(|j| {
|
|
// a.zero_at(i, j);
|
|
// })
|
|
// });
|
|
|
|
// let k_rem: usize = k % log_base2k;
|
|
|
|
// if k_rem != 0 {
|
|
// let carry: &mut [V::Scalar] = cast_mut(tmp_bytes);
|
|
|
|
// unsafe {
|
|
// std::ptr::write_bytes(carry.as_mut_ptr(), 0, n * size_of::<V::Scalar>());
|
|
// }
|
|
|
|
// let log_base2k_t: V::Scalar = V::Scalar::try_from(log_base2k).unwrap();
|
|
// let shift: V::Scalar = V::Scalar::try_from(V::Scalar::BITS as usize - k_rem).unwrap();
|
|
// let k_rem_t: V::Scalar = V::Scalar::try_from(k_rem).unwrap();
|
|
|
|
// (steps..size).for_each(|i| {
|
|
// izip!(carry.iter_mut(), a.at_mut(a_col, i).iter_mut()).for_each(|(ci, xi)| {
|
|
// *xi += *ci << log_base2k_t;
|
|
// *ci = get_base_k_carry(*xi, shift);
|
|
// *xi = (*xi - *ci) >> k_rem_t;
|
|
// });
|
|
// })
|
|
// }
|
|
// }
|
|
|
|
// #[inline(always)]
|
|
// fn get_base_k_carry<T: Num>(x: T, shift: T) -> T {
|
|
// (x << shift) >> shift
|
|
// }
|
|
|
|
// pub fn rsh_tmp_bytes<T: Num>(n: usize) -> usize {
|
|
// n * std::mem::size_of::<T>()
|
|
// }
|
|
|
|
// pub trait ZnxLayout: ZnxInfos {
|
|
// type Scalar;
|
|
|
|
// /// Returns true if the receiver is only borrowing the data.
|
|
// fn borrowing(&self) -> bool {
|
|
// self.znx().data.len() == 0
|
|
// }
|
|
|
|
// /// Returns a non-mutable pointer to the underlying coefficients array.
|
|
// fn as_ptr(&self) -> *const Self::Scalar {
|
|
// self.znx().ptr as *const Self::Scalar
|
|
// }
|
|
|
|
// /// Returns a mutable pointer to the underlying coefficients array.
|
|
// fn as_mut_ptr(&mut self) -> *mut Self::Scalar {
|
|
// self.znx_mut().ptr as *mut Self::Scalar
|
|
// }
|
|
|
|
// /// Returns a non-mutable reference to the entire underlying coefficient array.
|
|
// fn raw(&self) -> &[Self::Scalar] {
|
|
// unsafe { std::slice::from_raw_parts(self.as_ptr(), self.n() * self.poly_count()) }
|
|
// }
|
|
|
|
// /// Returns a mutable reference to the entire underlying coefficient array.
|
|
// fn raw_mut(&mut self) -> &mut [Self::Scalar] {
|
|
// unsafe { std::slice::from_raw_parts_mut(self.as_mut_ptr(), self.n() * self.poly_count()) }
|
|
// }
|
|
|
|
// /// Returns a non-mutable pointer starting at the j-th small polynomial of the i-th column.
|
|
// fn at_ptr(&self, i: usize, j: usize) -> *const Self::Scalar {
|
|
// #[cfg(debug_assertions)]
|
|
// {
|
|
// assert!(i < self.cols());
|
|
// assert!(j < self.size());
|
|
// }
|
|
// let offset: usize = self.n() * (j * self.cols() + i);
|
|
// unsafe { self.as_ptr().add(offset) }
|
|
// }
|
|
|
|
// /// Returns a mutable pointer starting at the j-th small polynomial of the i-th column.
|
|
// fn at_mut_ptr(&mut self, i: usize, j: usize) -> *mut Self::Scalar {
|
|
// #[cfg(debug_assertions)]
|
|
// {
|
|
// assert!(i < self.cols());
|
|
// assert!(j < self.size());
|
|
// }
|
|
// let offset: usize = self.n() * (j * self.cols() + i);
|
|
// unsafe { self.as_mut_ptr().add(offset) }
|
|
// }
|
|
|
|
// /// Returns non-mutable reference to the (i, j)-th small polynomial.
|
|
// fn at(&self, i: usize, j: usize) -> &[Self::Scalar] {
|
|
// unsafe { std::slice::from_raw_parts(self.at_ptr(i, j), self.n()) }
|
|
// }
|
|
|
|
// /// Returns mutable reference to the (i, j)-th small polynomial.
|
|
// fn at_mut(&mut self, i: usize, j: usize) -> &mut [Self::Scalar] {
|
|
// unsafe { std::slice::from_raw_parts_mut(self.at_mut_ptr(i, j), self.n()) }
|
|
// }
|
|
// }
|