mirror of
https://github.com/arnaucube/poulpy.git
synced 2026-02-10 21:26:41 +01:00
wip
This commit is contained in:
@@ -1,12 +1,16 @@
|
||||
use crate::Backend;
|
||||
use crate::DataView;
|
||||
use crate::DataViewMut;
|
||||
use crate::Module;
|
||||
use crate::ZnxView;
|
||||
use crate::alloc_aligned;
|
||||
use crate::assert_alignement;
|
||||
use crate::cast_mut;
|
||||
use crate::ffi::znx;
|
||||
use crate::znx_base::{GetZnxBase, ZnxAlloc, ZnxBase, ZnxInfos, ZnxLayout, ZnxRsh, ZnxSliceSize, ZnxZero, switch_degree};
|
||||
use std::cmp::min;
|
||||
use crate::znx_base::{GetZnxBase, ZnxAlloc, ZnxBase, ZnxInfos, ZnxRsh, ZnxZero, switch_degree};
|
||||
use std::{cmp::min, fmt};
|
||||
|
||||
pub const VEC_ZNX_ROWS: usize = 1;
|
||||
// pub const VEC_ZNX_ROWS: usize = 1;
|
||||
|
||||
/// [VecZnx] represents collection of contiguously stacked vector of small norm polynomials of
|
||||
/// Zn\[X\] with [i64] coefficients.
|
||||
@@ -18,68 +22,57 @@ pub const VEC_ZNX_ROWS: usize = 1;
|
||||
/// Given 3 polynomials (a, b, c) of Zn\[X\], each with 4 columns, then the memory
|
||||
/// layout is: `[a0, b0, c0, a1, b1, c1, a2, b2, c2, a3, b3, c3]`, where ai, bi, ci
|
||||
/// are small polynomials of Zn\[X\].
|
||||
pub struct VecZnx {
|
||||
pub inner: ZnxBase,
|
||||
pub struct VecZnx<D> {
|
||||
data: D,
|
||||
n: usize,
|
||||
cols: usize,
|
||||
size: usize,
|
||||
}
|
||||
|
||||
impl GetZnxBase for VecZnx {
|
||||
fn znx(&self) -> &ZnxBase {
|
||||
&self.inner
|
||||
impl<D> ZnxInfos for VecZnx<D> {
|
||||
fn cols(&self) -> usize {
|
||||
self.cols
|
||||
}
|
||||
|
||||
fn znx_mut(&mut self) -> &mut ZnxBase {
|
||||
&mut self.inner
|
||||
fn rows(&self) -> usize {
|
||||
1
|
||||
}
|
||||
}
|
||||
|
||||
impl ZnxInfos for VecZnx {}
|
||||
fn n(&self) -> usize {
|
||||
self.n
|
||||
}
|
||||
|
||||
fn size(&self) -> usize {
|
||||
self.size
|
||||
}
|
||||
|
||||
impl ZnxSliceSize for VecZnx {
|
||||
fn sl(&self) -> usize {
|
||||
self.cols() * self.n()
|
||||
}
|
||||
}
|
||||
|
||||
impl ZnxLayout for VecZnx {
|
||||
type Scalar = i64;
|
||||
}
|
||||
|
||||
impl ZnxZero for VecZnx {}
|
||||
|
||||
impl ZnxRsh for VecZnx {}
|
||||
|
||||
impl<B: Backend> ZnxAlloc<B> for VecZnx {
|
||||
type Scalar = i64;
|
||||
|
||||
fn from_bytes_borrow(module: &Module<B>, _rows: usize, cols: usize, size: usize, bytes: &mut [u8]) -> VecZnx {
|
||||
debug_assert_eq!(bytes.len(), Self::bytes_of(module, _rows, cols, size));
|
||||
VecZnx {
|
||||
inner: ZnxBase::from_bytes_borrow(module.n(), VEC_ZNX_ROWS, cols, size, bytes),
|
||||
}
|
||||
}
|
||||
|
||||
fn bytes_of(module: &Module<B>, _rows: usize, cols: usize, size: usize) -> usize {
|
||||
debug_assert_eq!(
|
||||
_rows, VEC_ZNX_ROWS,
|
||||
"rows != {} not supported for VecZnx",
|
||||
VEC_ZNX_ROWS
|
||||
);
|
||||
module.n() * cols * size * size_of::<Self::Scalar>()
|
||||
impl<D> DataView for VecZnx<D> {
|
||||
type D = D;
|
||||
fn data(&self) -> &Self::D {
|
||||
&self.data
|
||||
}
|
||||
}
|
||||
|
||||
/// Copies the coefficients of `a` on the receiver.
|
||||
/// Copy is done with the minimum size matching both backing arrays.
|
||||
/// Panics if the cols do not match.
|
||||
pub fn copy_vec_znx_from(b: &mut VecZnx, a: &VecZnx) {
|
||||
assert_eq!(b.cols(), a.cols());
|
||||
let data_a: &[i64] = a.raw();
|
||||
let data_b: &mut [i64] = b.raw_mut();
|
||||
let size = min(data_b.len(), data_a.len());
|
||||
data_b[..size].copy_from_slice(&data_a[..size])
|
||||
impl<D> DataViewMut for VecZnx<D> {
|
||||
fn data_mut(&self) -> &mut Self::D {
|
||||
&mut self.data
|
||||
}
|
||||
}
|
||||
|
||||
impl VecZnx {
|
||||
impl<D: AsRef<[u8]>> ZnxView for VecZnx<D> {
|
||||
type Scalar = i64;
|
||||
}
|
||||
|
||||
impl<D: AsMut<[u8]> + AsRef<[u8]>> VecZnx<D> {
|
||||
pub fn normalize(&mut self, log_base2k: usize, col: usize, carry: &mut [u8]) {
|
||||
normalize(log_base2k, self, col, carry)
|
||||
}
|
||||
|
||||
/// Truncates the precision of the [VecZnx] by k bits.
|
||||
///
|
||||
/// # Arguments
|
||||
@@ -91,12 +84,6 @@ impl VecZnx {
|
||||
return;
|
||||
}
|
||||
|
||||
if !self.borrowing() {
|
||||
self.inner
|
||||
.data
|
||||
.truncate(self.n() * self.cols() * (self.size() - k / log_base2k));
|
||||
}
|
||||
|
||||
self.inner.size -= k / log_base2k;
|
||||
|
||||
let k_rem: usize = k % log_base2k;
|
||||
@@ -109,29 +96,72 @@ impl VecZnx {
|
||||
}
|
||||
}
|
||||
|
||||
pub fn copy_from(&mut self, a: &Self) {
|
||||
copy_vec_znx_from(self, a);
|
||||
}
|
||||
|
||||
pub fn normalize(&mut self, log_base2k: usize, col: usize, carry: &mut [u8]) {
|
||||
normalize(log_base2k, self, col, carry)
|
||||
}
|
||||
|
||||
pub fn switch_degree(&self, col: usize, a: &mut Self, col_a: usize) {
|
||||
switch_degree(a, col_a, self, col)
|
||||
/// Switches degree of from `a.n()` to `self.n()` into `self`
|
||||
pub fn switch_degree<Data: AsRef<[u8]>>(&mut self, col: usize, a: &Data, col_a: usize) {
|
||||
switch_degree(self, col_a, a, col)
|
||||
}
|
||||
|
||||
// Prints the first `n` coefficients of each limb
|
||||
pub fn print(&self, n: usize, col: usize) {
|
||||
(0..self.size()).for_each(|j| println!("{}: {:?}", j, &self.at(col, j)[..n]));
|
||||
// pub fn print(&self, n: usize, col: usize) {
|
||||
// (0..self.size()).for_each(|j| println!("{}: {:?}", j, &self.at(col, j)[..n]));
|
||||
// }
|
||||
}
|
||||
|
||||
impl<D: From<Vec<u8>>> VecZnx<D> {
|
||||
pub(crate) fn bytes_of<Scalar: Sized>(n: usize, cols: usize, size: usize) -> usize {
|
||||
n * cols * size * size_of::<Scalar>()
|
||||
}
|
||||
|
||||
pub(crate) fn new<Scalar: Sized>(n: usize, cols: usize, size: usize) -> Self {
|
||||
let data = alloc_aligned::<u8>(Self::bytes_of::<Scalar>(n, cols, size));
|
||||
Self {
|
||||
data: data.into(),
|
||||
n,
|
||||
cols,
|
||||
size,
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn new_from_bytes<Scalar: Sized>(n: usize, cols: usize, size: usize, bytes: impl Into<Vec<u8>>) -> Self {
|
||||
let data: Vec<u8> = bytes.into();
|
||||
assert!(data.len() == Self::bytes_of::<Scalar>(n, cols, size));
|
||||
Self {
|
||||
data: data.into(),
|
||||
n,
|
||||
cols,
|
||||
size,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
//(Jay)TODO: Impl. truncate pow2 for Owned Vector
|
||||
|
||||
/// Copies the coefficients of `a` on the receiver.
|
||||
/// Copy is done with the minimum size matching both backing arrays.
|
||||
/// Panics if the cols do not match.
|
||||
pub fn copy_vec_znx_from<DataMut, Data>(b: &mut VecZnx<DataMut>, a: &VecZnx<Data>)
|
||||
where
|
||||
DataMut: AsMut<[u8]> + AsRef<[u8]>,
|
||||
Data: AsRef<[u8]>,
|
||||
{
|
||||
assert_eq!(b.cols(), a.cols());
|
||||
let data_a: &[i64] = a.raw();
|
||||
let data_b: &mut [i64] = b.raw_mut();
|
||||
let size = min(data_b.len(), data_a.len());
|
||||
data_b[..size].copy_from_slice(&data_a[..size])
|
||||
}
|
||||
|
||||
// if !self.borrowing() {
|
||||
// self.inner
|
||||
// .data
|
||||
// .truncate(self.n() * self.cols() * (self.size() - k / log_base2k));
|
||||
// }
|
||||
|
||||
fn normalize_tmp_bytes(n: usize) -> usize {
|
||||
n * std::mem::size_of::<i64>()
|
||||
}
|
||||
|
||||
fn normalize(log_base2k: usize, a: &mut VecZnx, a_col: usize, tmp_bytes: &mut [u8]) {
|
||||
fn normalize<D: AsMut<[u8]>>(log_base2k: usize, a: &mut VecZnx<D>, a_col: usize, tmp_bytes: &mut [u8]) {
|
||||
let n: usize = a.n();
|
||||
|
||||
debug_assert!(
|
||||
@@ -162,3 +192,62 @@ fn normalize(log_base2k: usize, a: &mut VecZnx, a_col: usize, tmp_bytes: &mut [u
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
// impl<B: Backend> ZnxAlloc<B> for VecZnx {
|
||||
// type Scalar = i64;
|
||||
|
||||
// fn from_bytes_borrow(module: &Module<B>, _rows: usize, cols: usize, size: usize, bytes: &mut [u8]) -> VecZnx {
|
||||
// debug_assert_eq!(bytes.len(), Self::bytes_of(module, _rows, cols, size));
|
||||
// VecZnx {
|
||||
// inner: ZnxBase::from_bytes_borrow(module.n(), VEC_ZNX_ROWS, cols, size, bytes),
|
||||
// }
|
||||
// }
|
||||
|
||||
// fn bytes_of(module: &Module<B>, _rows: usize, cols: usize, size: usize) -> usize {
|
||||
// debug_assert_eq!(
|
||||
// _rows, VEC_ZNX_ROWS,
|
||||
// "rows != {} not supported for VecZnx",
|
||||
// VEC_ZNX_ROWS
|
||||
// );
|
||||
// module.n() * cols * size * size_of::<Self::Scalar>()
|
||||
// }
|
||||
// }
|
||||
|
||||
impl<D: AsRef<[u8]>> fmt::Display for VecZnx<D> {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
writeln!(
|
||||
f,
|
||||
"VecZnx(n={}, cols={}, size={})",
|
||||
self.n, self.cols, self.size
|
||||
)?;
|
||||
|
||||
for col in 0..self.cols {
|
||||
writeln!(f, "Column {}:", col)?;
|
||||
for size in 0..self.size {
|
||||
let coeffs = self.at(col, size);
|
||||
write!(f, " Size {}: [", size)?;
|
||||
|
||||
let max_show = 100;
|
||||
let show_count = coeffs.len().min(max_show);
|
||||
|
||||
for (i, &coeff) in coeffs.iter().take(show_count).enumerate() {
|
||||
if i > 0 {
|
||||
write!(f, ", ")?;
|
||||
}
|
||||
write!(f, "{}", coeff)?;
|
||||
}
|
||||
|
||||
if coeffs.len() > max_show {
|
||||
write!(f, ", ... ({} more)", coeffs.len() - max_show)?;
|
||||
}
|
||||
|
||||
writeln!(f, "]")?;
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
pub type VecZnxOwned = VecZnx<Vec<u8>>;
|
||||
pub type VecZnxMut<'a> = VecZnx<&'a mut [u8]>;
|
||||
pub type VecZnxRef<'a> = VecZnx<&'a [u8]>;
|
||||
|
||||
Reference in New Issue
Block a user