use itertools::izip;

use crate::DataView;
use crate::DataViewMut;
use crate::ScalarZnx;
use crate::Scratch;
use crate::ZnxSliceSize;
use crate::ZnxZero;
use crate::alloc_aligned;
use crate::assert_alignement;
use crate::cast_mut;
use crate::ffi::znx;
use crate::znx_base::{ZnxInfos, ZnxView, ZnxViewMut};
use std::{cmp::min, fmt};

/// [VecZnx] represents a collection of contiguously stacked vectors of small-norm
/// polynomials of Zn\[X\] with [i64] coefficients.
/// A [VecZnx] is composed of multiple Zn\[X\] polynomials stored in a single
/// contiguous array in memory.
///
/// # Example
///
/// Given 3 polynomials (a, b, c) of Zn\[X\], each decomposed into 4 size limbs,
/// the memory layout is: `[a0, b0, c0, a1, b1, c1, a2, b2, c2, a3, b3, c3]`,
/// where ai, bi, ci are small polynomials of Zn\[X\].
pub struct VecZnx<D> {
    pub data: D,
    pub n: usize,
    pub cols: usize,
    pub size: usize,
}

impl<D> ZnxInfos for VecZnx<D> {
    fn cols(&self) -> usize {
        self.cols
    }

    fn rows(&self) -> usize {
        1
    }

    fn n(&self) -> usize {
        self.n
    }

    fn size(&self) -> usize {
        self.size
    }
}

impl<D> ZnxSliceSize for VecZnx<D> {
    fn sl(&self) -> usize {
        self.n() * self.cols()
    }
}

impl<D> DataView for VecZnx<D> {
    type D = D;
    fn data(&self) -> &Self::D {
        &self.data
    }
}

impl<D> DataViewMut for VecZnx<D> {
    fn data_mut(&mut self) -> &mut Self::D {
        &mut self.data
    }
}

impl<D: AsRef<[u8]>> ZnxView for VecZnx<D> {
    type Scalar = i64;
}

impl<D: AsRef<[u8]>> VecZnx<D> {
    pub fn rsh_scratch_space(n: usize) -> usize {
        n * std::mem::size_of::<i64>()
    }
}

impl<D: AsMut<[u8]> + AsRef<[u8]>> VecZnx<D> {
    /// Truncates the precision of the [VecZnx] by k bits.
    ///
    /// # Arguments
    ///
    /// * `basek`: the base two logarithm of the coefficients decomposition.
    /// * `k`: the number of bits of precision to drop.
    /// * `col`: the column on which the truncation is applied.
    pub fn trunc_pow2(&mut self, basek: usize, k: usize, col: usize) {
        if k == 0 {
            return;
        }

        self.size -= k / basek;

        let k_rem: usize = k % basek;

        if k_rem != 0 {
            let mask: i64 = ((1 << (basek - k_rem - 1)) - 1) << k_rem;
            self.at_mut(col, self.size() - 1)
                .iter_mut()
                .for_each(|x: &mut i64| *x &= mask)
        }
    }

    pub fn rsh(&mut self, basek: usize, k: usize, scratch: &mut Scratch) {
        let n: usize = self.n();
        let cols: usize = self.cols();
        let size: usize = self.size();

        let steps: usize = k / basek;

        self.raw_mut().rotate_right(n * steps * cols);

        (0..cols).for_each(|i| {
            (0..steps).for_each(|j| {
                self.zero_at(i, j);
            })
        });

        let k_rem: usize = k % basek;

        if k_rem != 0 {
            let (carry, _) = scratch.tmp_slice::<i64>(n);

            let shift = i64::BITS as usize - k_rem;
            (0..cols).for_each(|i| {
                carry.fill(0);
                (steps..size).for_each(|j| {
                    izip!(carry.iter_mut(), self.at_mut(i, j).iter_mut()).for_each(|(ci, xi)| {
                        *xi += *ci << basek;
                        *ci = (*xi << shift) >> shift;
                        *xi = (*xi - *ci) >> k_rem;
                    });
                });
            })
        }
    }
}

impl<D: From<Vec<u8>>> VecZnx<D> {
    pub(crate) fn bytes_of<Scalar: Sized>(n: usize, cols: usize, size: usize) -> usize {
        n * cols * size * size_of::<Scalar>()
    }

    pub(crate) fn new<Scalar: Sized>(n: usize, cols: usize, size: usize) -> Self {
        let data = alloc_aligned::<u8>(Self::bytes_of::<Scalar>(n, cols, size));
        Self {
            data: data.into(),
            n,
            cols,
            size,
        }
    }

    pub(crate) fn new_from_bytes<Scalar: Sized>(n: usize, cols: usize, size: usize, bytes: impl Into<Vec<u8>>) -> Self {
        let data: Vec<u8> = bytes.into();
        assert!(data.len() == Self::bytes_of::<Scalar>(n, cols, size));
        Self {
            data: data.into(),
            n,
            cols,
            size,
        }
    }
}

impl<D> VecZnx<D> {
    pub(crate) fn from_data(data: D, n: usize, cols: usize, size: usize) -> Self {
        Self {
            data,
            n,
            cols,
            size,
        }
    }

    pub fn to_scalar_znx(self) -> ScalarZnx<D> {
        debug_assert_eq!(
            self.size, 1,
            "cannot convert VecZnx to ScalarZnx if size: {} != 1",
            self.size
        );
        ScalarZnx {
            data: self.data,
            n: self.n,
            cols: self.cols,
        }
    }
}
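// The sketch below is not part of the original API: a hypothetical helper
// (`coeff_offset` is an illustrative name) that spells out the contiguous
// layout documented on [VecZnx]. With `ZnxSliceSize::sl() = n * cols`,
// limb `j` of column `i` starts at offset `j * n * cols + i * n` in the
// raw i64 slice, matching the `[a0, b0, c0, a1, ...]` example above.
#[allow(dead_code)]
fn coeff_offset(n: usize, cols: usize, col: usize, limb: usize, coeff: usize) -> usize {
    debug_assert!(col < cols && coeff < n);
    limb * n * cols + col * n + coeff
}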
/// Copies the coefficients of `a` onto the receiver `b`.
/// The copy covers the minimum size matching both backing arrays.
/// Panics if the cols do not match.
pub fn copy_vec_znx_from<DataMut, Data>(b: &mut VecZnx<DataMut>, a: &VecZnx<Data>)
where
    DataMut: AsMut<[u8]> + AsRef<[u8]>,
    Data: AsRef<[u8]>,
{
    assert_eq!(b.cols(), a.cols());
    let data_a: &[i64] = a.raw();
    let data_b: &mut [i64] = b.raw_mut();
    let size = min(data_b.len(), data_a.len());
    data_b[..size].copy_from_slice(&data_a[..size])
}

#[allow(dead_code)]
fn normalize_tmp_bytes(n: usize) -> usize {
    n * std::mem::size_of::<i64>()
}

#[allow(dead_code)]
fn normalize<D: AsMut<[u8]> + AsRef<[u8]>>(basek: usize, a: &mut VecZnx<D>, a_col: usize, tmp_bytes: &mut [u8]) {
    let n: usize = a.n();

    debug_assert!(
        tmp_bytes.len() >= normalize_tmp_bytes(n),
        "invalid tmp_bytes: tmp_bytes.len()={} < normalize_tmp_bytes({})",
        tmp_bytes.len(),
        n,
    );
    #[cfg(debug_assertions)]
    {
        assert_alignement(tmp_bytes.as_ptr())
    }

    let carry_i64: &mut [i64] = cast_mut(tmp_bytes);

    unsafe {
        znx::znx_zero_i64_ref(n as u64, carry_i64.as_mut_ptr());
        (0..a.size()).rev().for_each(|i| {
            znx::znx_normalize(
                n as u64,
                basek as u64,
                a.at_mut_ptr(a_col, i),
                carry_i64.as_mut_ptr(),
                a.at_mut_ptr(a_col, i),
                carry_i64.as_mut_ptr(),
            )
        });
    }
}

impl<D> VecZnx<D>
where
    VecZnx<D>: VecZnxToMut + ZnxInfos,
{
    /// Extracts the `a_col`-th column of `a` and stores it in the `self_col`-th column of [Self].
    pub fn extract_column<R>(&mut self, self_col: usize, a: &R, a_col: usize)
    where
        R: VecZnxToRef + ZnxInfos,
    {
        #[cfg(debug_assertions)]
        {
            assert!(self_col < self.cols());
            assert!(a_col < a.cols());
        }

        let min_size: usize = self.size.min(a.size());
        let max_size: usize = self.size;

        let mut self_mut: VecZnx<&mut [u8]> = self.to_mut();
        let a_ref: VecZnx<&[u8]> = a.to_ref();

        (0..min_size).for_each(|i: usize| {
            self_mut
                .at_mut(self_col, i)
                .copy_from_slice(a_ref.at(a_col, i));
        });

        (min_size..max_size).for_each(|i| {
            self_mut.zero_at(self_col, i);
        });
    }
}
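// A minimal usage sketch, assuming crate-internal access (`new` is
// pub(crate)) and illustrative dimensions; the module and test names are
// hypothetical, not part of the original file. It follows one tagged
// coefficient through `copy_vec_znx_from` and `extract_column`.
#[cfg(test)]
mod copy_extract_sketch {
    use super::*;

    #[test]
    fn copy_then_extract_column() {
        let n: usize = 8;

        // Two columns, three limbs; tag one coefficient to track it.
        let mut a: VecZnxOwned = VecZnx::new::<i64>(n, 2, 3);
        a.at_mut(1, 0)[0] = 42;

        // The copy covers the minimum of the two raw lengths; since the
        // layout is limb-major, the first two limbs of every column
        // survive in `b`.
        let mut b: VecZnxOwned = VecZnx::new::<i64>(n, 2, 2);
        copy_vec_znx_from(&mut b, &a);
        assert_eq!(b.at(1, 0)[0], 42);

        // Column 1 of `a` lands in column 0 of `c`; limbs of `c` beyond
        // a.size() would be zeroed.
        let mut c: VecZnxOwned = VecZnx::new::<i64>(n, 1, 3);
        c.extract_column(0, &a, 1);
        assert_eq!(c.at(0, 0)[0], 42);
    }
}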
({} more)", coeffs.len() - max_show)?; } writeln!(f, "]")?; } } Ok(()) } } pub type VecZnxOwned = VecZnx>; pub type VecZnxMut<'a> = VecZnx<&'a mut [u8]>; pub type VecZnxRef<'a> = VecZnx<&'a [u8]>; pub trait VecZnxToRef { fn to_ref(&self) -> VecZnx<&[u8]>; } pub trait VecZnxToMut: VecZnxToRef { fn to_mut(&mut self) -> VecZnx<&mut [u8]>; } impl VecZnxToMut for VecZnx> { fn to_mut(&mut self) -> VecZnx<&mut [u8]> { VecZnx { data: self.data.as_mut_slice(), n: self.n, cols: self.cols, size: self.size, } } } impl VecZnxToRef for VecZnx> { fn to_ref(&self) -> VecZnx<&[u8]> { VecZnx { data: self.data.as_slice(), n: self.n, cols: self.cols, size: self.size, } } } impl VecZnxToMut for VecZnx<&mut [u8]> { fn to_mut(&mut self) -> VecZnx<&mut [u8]> { VecZnx { data: self.data, n: self.n, cols: self.cols, size: self.size, } } } impl VecZnxToRef for VecZnx<&mut [u8]> { fn to_ref(&self) -> VecZnx<&[u8]> { VecZnx { data: self.data, n: self.n, cols: self.cols, size: self.size, } } } impl VecZnxToRef for VecZnx<&[u8]> { fn to_ref(&self) -> VecZnx<&[u8]> { VecZnx { data: self.data, n: self.n, cols: self.cols, size: self.size, } } } impl VecZnx where VecZnx: VecZnxToRef, { pub fn clone(&self) -> VecZnx> { let self_ref: VecZnx<&[u8]> = self.to_ref(); VecZnx { data: self_ref.data.to_vec(), n: self_ref.n, cols: self_ref.cols, size: self_ref.size, } } }