use crate::{Backend, Module, alloc_aligned, assert_alignement, cast_mut}; use itertools::izip; use std::cmp::min; pub struct ZnxBase { /// The ring degree pub n: usize, /// The number of rows (in the third dimension) pub rows: usize, /// The number of polynomials pub cols: usize, /// The number of size per polynomial (a.k.a small polynomials). pub size: usize, /// Polynomial coefficients, as a contiguous array. Each col is equally spaced by n. pub data: Vec, /// Pointer to data (data can be enpty if [VecZnx] borrows space instead of owning it). pub ptr: *mut u8, } impl ZnxBase { pub fn from_bytes(n: usize, rows: usize, cols: usize, size: usize, mut bytes: Vec) -> Self { let mut res: Self = Self::from_bytes_borrow(n, rows, cols, size, &mut bytes); res.data = bytes; res } pub fn from_bytes_borrow(n: usize, rows: usize, cols: usize, size: usize, bytes: &mut [u8]) -> Self { #[cfg(debug_assertions)] { assert_eq!(n & (n - 1), 0, "n must be a power of two"); assert!(n > 0, "n must be greater than 0"); assert!(rows > 0, "rows must be greater than 0"); assert!(cols > 0, "cols must be greater than 0"); assert!(size > 0, "size must be greater than 0"); } Self { n: n, rows: rows, cols: cols, size: size, data: Vec::new(), ptr: bytes.as_mut_ptr(), } } } pub trait GetZnxBase { fn znx(&self) -> &ZnxBase; fn znx_mut(&mut self) -> &mut ZnxBase; } pub trait ZnxInfos { /// Returns the ring degree of the polynomials. fn n(&self) -> usize; /// Returns the base two logarithm of the ring dimension of the polynomials. fn log_n(&self) -> usize { (usize::BITS - (self.n() - 1).leading_zeros()) as _ } /// Returns the number of rows. fn rows(&self) -> usize; /// Returns the number of polynomials in each row. fn cols(&self) -> usize; /// Returns the number of size per polynomial. fn size(&self) -> usize; /// Returns the total number of small polynomials. fn poly_count(&self) -> usize { self.rows() * self.cols() * self.size() } /// Returns the slice size, which is the offset between /// two size of the same column. fn sl(&self) -> usize; } // pub trait ZnxSliceSize {} //(Jay) TODO: Remove ZnxAlloc // pub trait ZnxAlloc // where // Self: Sized + ZnxInfos, // { // type Scalar; // fn new(module: &Module, rows: usize, cols: usize, size: usize) -> Self { // let bytes: Vec = alloc_aligned::(Self::bytes_of(module, rows, cols, size)); // Self::from_bytes(module, rows, cols, size, bytes) // } // fn from_bytes(module: &Module, rows: usize, cols: usize, size: usize, mut bytes: Vec) -> Self { // let mut res: Self = Self::from_bytes_borrow(module, rows, cols, size, &mut bytes); // res.znx_mut().data = bytes; // res // } // fn from_bytes_borrow(module: &Module, rows: usize, cols: usize, size: usize, bytes: &mut [u8]) -> Self; // fn bytes_of(module: &Module, rows: usize, cols: usize, size: usize) -> usize; // } pub trait DataView { type D; fn data(&self) -> &Self::D; } pub trait DataViewMut: DataView { fn data_mut(&mut self) -> &mut Self::D; } pub trait ZnxView: ZnxInfos + DataView> { type Scalar: Copy; /// Returns a non-mutable pointer to the underlying coefficients array. fn as_ptr(&self) -> *const Self::Scalar { self.data().as_ref().as_ptr() as *const Self::Scalar } /// Returns a non-mutable reference to the entire underlying coefficient array. fn raw(&self) -> &[Self::Scalar] { unsafe { std::slice::from_raw_parts(self.as_ptr(), self.n() * self.poly_count()) } } /// Returns a non-mutable pointer starting at the j-th small polynomial of the i-th column. fn at_ptr(&self, i: usize, j: usize) -> *const Self::Scalar { #[cfg(debug_assertions)] { assert!(i < self.cols()); assert!(j < self.size()); } let offset: usize = self.n() * (j * self.cols() + i); unsafe { self.as_ptr().add(offset) } } /// Returns non-mutable reference to the (i, j)-th small polynomial. fn at(&self, i: usize, j: usize) -> &[Self::Scalar] { unsafe { std::slice::from_raw_parts(self.at_ptr(i, j), self.n()) } } } pub trait ZnxViewMut: ZnxView + DataViewMut> { /// Returns a mutable pointer to the underlying coefficients array. fn as_mut_ptr(&mut self) -> *mut Self::Scalar { self.data_mut().as_mut().as_mut_ptr() as *mut Self::Scalar } /// Returns a mutable reference to the entire underlying coefficient array. fn raw_mut(&mut self) -> &mut [Self::Scalar] { unsafe { std::slice::from_raw_parts_mut(self.as_mut_ptr(), self.n() * self.poly_count()) } } /// Returns a mutable pointer starting at the j-th small polynomial of the i-th column. fn at_mut_ptr(&mut self, i: usize, j: usize) -> *mut Self::Scalar { #[cfg(debug_assertions)] { assert!(i < self.cols()); assert!(j < self.size()); } let offset: usize = self.n() * (j * self.cols() + i); unsafe { self.as_mut_ptr().add(offset) } } /// Returns mutable reference to the (i, j)-th small polynomial. fn at_mut(&mut self, i: usize, j: usize) -> &mut [Self::Scalar] { unsafe { std::slice::from_raw_parts_mut(self.at_mut_ptr(i, j), self.n()) } } } //(Jay)Note: Can't provide blanket impl. of ZnxView because Scalar is not known impl ZnxViewMut for T where T: ZnxView + DataViewMut> {} use std::convert::TryFrom; use std::ops::{Add, AddAssign, Div, Mul, Neg, Shl, Shr, Sub}; pub trait Num: Copy + Default + PartialEq + PartialOrd + Add + Sub + Mul + Div + Neg + AddAssign { const BITS: u32; } impl Num for i64 { const BITS: u32 = 64; } impl Num for i128 { const BITS: u32 = 128; } impl Num for f64 { const BITS: u32 = 64; } pub trait ZnxZero: ZnxViewMut where Self: Sized, { fn zero(&mut self) { unsafe { std::ptr::write_bytes( self.as_mut_ptr(), 0, self.n() * size_of::() * self.poly_count(), ); } } fn zero_at(&mut self, i: usize, j: usize) { unsafe { std::ptr::write_bytes( self.at_mut_ptr(i, j), 0, self.n() * size_of::(), ); } } } // Blanket implementations impl ZnxZero for T where T: ZnxViewMut {} // impl ZnxRsh for T where T: ZnxZero {} pub fn switch_degree + ZnxZero, D: ZnxView>( b: &mut DMut, col_b: usize, a: &D, col_a: usize, ) { let (n_in, n_out) = (a.n(), b.n()); let (gap_in, gap_out): (usize, usize); if n_in > n_out { (gap_in, gap_out) = (n_in / n_out, 1) } else { (gap_in, gap_out) = (1, n_out / n_in); b.zero(); } let size: usize = min(a.size(), b.size()); (0..size).for_each(|i| { izip!( a.at(col_a, i).iter().step_by(gap_in), b.at_mut(col_b, i).iter_mut().step_by(gap_out) ) .for_each(|(x_in, x_out)| *x_out = *x_in); }); } // (Jay)TODO: implement rsh for VecZnx, VecZnxBig // pub trait ZnxRsh: ZnxZero { // fn rsh(&mut self, k: usize, log_base2k: usize, col: usize, carry: &mut [u8]) { // rsh(k, log_base2k, self, col, carry) // } // } // pub fn rsh(k: usize, log_base2k: usize, a: &mut V, a_col: usize, tmp_bytes: &mut [u8]) { // let n: usize = a.n(); // let size: usize = a.size(); // let cols: usize = a.cols(); // #[cfg(debug_assertions)] // { // assert!( // tmp_bytes.len() >= rsh_tmp_bytes::(n), // "invalid carry: carry.len()/size_ofSelf::Scalar={} < rsh_tmp_bytes({}, {})", // tmp_bytes.len() / size_of::(), // n, // size, // ); // assert_alignement(tmp_bytes.as_ptr()); // } // let size: usize = a.size(); // let steps: usize = k / log_base2k; // a.raw_mut().rotate_right(n * steps * cols); // (0..cols).for_each(|i| { // (0..steps).for_each(|j| { // a.zero_at(i, j); // }) // }); // let k_rem: usize = k % log_base2k; // if k_rem != 0 { // let carry: &mut [V::Scalar] = cast_mut(tmp_bytes); // unsafe { // std::ptr::write_bytes(carry.as_mut_ptr(), 0, n * size_of::()); // } // let log_base2k_t: V::Scalar = V::Scalar::try_from(log_base2k).unwrap(); // let shift: V::Scalar = V::Scalar::try_from(V::Scalar::BITS as usize - k_rem).unwrap(); // let k_rem_t: V::Scalar = V::Scalar::try_from(k_rem).unwrap(); // (steps..size).for_each(|i| { // izip!(carry.iter_mut(), a.at_mut(a_col, i).iter_mut()).for_each(|(ci, xi)| { // *xi += *ci << log_base2k_t; // *ci = get_base_k_carry(*xi, shift); // *xi = (*xi - *ci) >> k_rem_t; // }); // }) // } // } // #[inline(always)] // fn get_base_k_carry(x: T, shift: T) -> T { // (x << shift) >> shift // } // pub fn rsh_tmp_bytes(n: usize) -> usize { // n * std::mem::size_of::() // } // pub trait ZnxLayout: ZnxInfos { // type Scalar; // /// Returns true if the receiver is only borrowing the data. // fn borrowing(&self) -> bool { // self.znx().data.len() == 0 // } // /// Returns a non-mutable pointer to the underlying coefficients array. // fn as_ptr(&self) -> *const Self::Scalar { // self.znx().ptr as *const Self::Scalar // } // /// Returns a mutable pointer to the underlying coefficients array. // fn as_mut_ptr(&mut self) -> *mut Self::Scalar { // self.znx_mut().ptr as *mut Self::Scalar // } // /// Returns a non-mutable reference to the entire underlying coefficient array. // fn raw(&self) -> &[Self::Scalar] { // unsafe { std::slice::from_raw_parts(self.as_ptr(), self.n() * self.poly_count()) } // } // /// Returns a mutable reference to the entire underlying coefficient array. // fn raw_mut(&mut self) -> &mut [Self::Scalar] { // unsafe { std::slice::from_raw_parts_mut(self.as_mut_ptr(), self.n() * self.poly_count()) } // } // /// Returns a non-mutable pointer starting at the j-th small polynomial of the i-th column. // fn at_ptr(&self, i: usize, j: usize) -> *const Self::Scalar { // #[cfg(debug_assertions)] // { // assert!(i < self.cols()); // assert!(j < self.size()); // } // let offset: usize = self.n() * (j * self.cols() + i); // unsafe { self.as_ptr().add(offset) } // } // /// Returns a mutable pointer starting at the j-th small polynomial of the i-th column. // fn at_mut_ptr(&mut self, i: usize, j: usize) -> *mut Self::Scalar { // #[cfg(debug_assertions)] // { // assert!(i < self.cols()); // assert!(j < self.size()); // } // let offset: usize = self.n() * (j * self.cols() + i); // unsafe { self.as_mut_ptr().add(offset) } // } // /// Returns non-mutable reference to the (i, j)-th small polynomial. // fn at(&self, i: usize, j: usize) -> &[Self::Scalar] { // unsafe { std::slice::from_raw_parts(self.at_ptr(i, j), self.n()) } // } // /// Returns mutable reference to the (i, j)-th small polynomial. // fn at_mut(&mut self, i: usize, j: usize) -> &mut [Self::Scalar] { // unsafe { std::slice::from_raw_parts_mut(self.at_mut_ptr(i, j), self.n()) } // } // }