diff --git a/base2k/spqlios-arithmetic b/base2k/spqlios-arithmetic index 8135d85..e3d3247 160000 --- a/base2k/spqlios-arithmetic +++ b/base2k/spqlios-arithmetic @@ -1 +1 @@ -Subproject commit 8135d85e7ac14601568fdd228e7dedf88994f7cf +Subproject commit e3d3247335faccf2b6361213c354cd61b958325e diff --git a/base2k/src/vec_znx.rs b/base2k/src/vec_znx.rs index 544c096..b76f93d 100644 --- a/base2k/src/vec_znx.rs +++ b/base2k/src/vec_znx.rs @@ -1,12 +1,16 @@ use crate::Backend; +use crate::DataView; +use crate::DataViewMut; use crate::Module; +use crate::ZnxView; +use crate::alloc_aligned; use crate::assert_alignement; use crate::cast_mut; use crate::ffi::znx; -use crate::znx_base::{GetZnxBase, ZnxAlloc, ZnxBase, ZnxInfos, ZnxLayout, ZnxRsh, ZnxSliceSize, ZnxZero, switch_degree}; -use std::cmp::min; +use crate::znx_base::{GetZnxBase, ZnxAlloc, ZnxBase, ZnxInfos, ZnxRsh, ZnxZero, switch_degree}; +use std::{cmp::min, fmt}; -pub const VEC_ZNX_ROWS: usize = 1; +// pub const VEC_ZNX_ROWS: usize = 1; /// [VecZnx] represents collection of contiguously stacked vector of small norm polynomials of /// Zn\[X\] with [i64] coefficients. @@ -18,68 +22,57 @@ pub const VEC_ZNX_ROWS: usize = 1; /// Given 3 polynomials (a, b, c) of Zn\[X\], each with 4 columns, then the memory /// layout is: `[a0, b0, c0, a1, b1, c1, a2, b2, c2, a3, b3, c3]`, where ai, bi, ci /// are small polynomials of Zn\[X\]. -pub struct VecZnx { - pub inner: ZnxBase, +pub struct VecZnx { + data: D, + n: usize, + cols: usize, + size: usize, } -impl GetZnxBase for VecZnx { - fn znx(&self) -> &ZnxBase { - &self.inner +impl ZnxInfos for VecZnx { + fn cols(&self) -> usize { + self.cols } - fn znx_mut(&mut self) -> &mut ZnxBase { - &mut self.inner + fn rows(&self) -> usize { + 1 } -} -impl ZnxInfos for VecZnx {} + fn n(&self) -> usize { + self.n + } + + fn size(&self) -> usize { + self.size + } -impl ZnxSliceSize for VecZnx { fn sl(&self) -> usize { self.cols() * self.n() } } -impl ZnxLayout for VecZnx { - type Scalar = i64; -} - -impl ZnxZero for VecZnx {} - -impl ZnxRsh for VecZnx {} - -impl ZnxAlloc for VecZnx { - type Scalar = i64; - - fn from_bytes_borrow(module: &Module, _rows: usize, cols: usize, size: usize, bytes: &mut [u8]) -> VecZnx { - debug_assert_eq!(bytes.len(), Self::bytes_of(module, _rows, cols, size)); - VecZnx { - inner: ZnxBase::from_bytes_borrow(module.n(), VEC_ZNX_ROWS, cols, size, bytes), - } - } - - fn bytes_of(module: &Module, _rows: usize, cols: usize, size: usize) -> usize { - debug_assert_eq!( - _rows, VEC_ZNX_ROWS, - "rows != {} not supported for VecZnx", - VEC_ZNX_ROWS - ); - module.n() * cols * size * size_of::() +impl DataView for VecZnx { + type D = D; + fn data(&self) -> &Self::D { + &self.data } } -/// Copies the coefficients of `a` on the receiver. -/// Copy is done with the minimum size matching both backing arrays. -/// Panics if the cols do not match. -pub fn copy_vec_znx_from(b: &mut VecZnx, a: &VecZnx) { - assert_eq!(b.cols(), a.cols()); - let data_a: &[i64] = a.raw(); - let data_b: &mut [i64] = b.raw_mut(); - let size = min(data_b.len(), data_a.len()); - data_b[..size].copy_from_slice(&data_a[..size]) +impl DataViewMut for VecZnx { + fn data_mut(&self) -> &mut Self::D { + &mut self.data + } } -impl VecZnx { +impl> ZnxView for VecZnx { + type Scalar = i64; +} + +impl + AsRef<[u8]>> VecZnx { + pub fn normalize(&mut self, log_base2k: usize, col: usize, carry: &mut [u8]) { + normalize(log_base2k, self, col, carry) + } + /// Truncates the precision of the [VecZnx] by k bits. /// /// # Arguments @@ -91,12 +84,6 @@ impl VecZnx { return; } - if !self.borrowing() { - self.inner - .data - .truncate(self.n() * self.cols() * (self.size() - k / log_base2k)); - } - self.inner.size -= k / log_base2k; let k_rem: usize = k % log_base2k; @@ -109,29 +96,72 @@ impl VecZnx { } } - pub fn copy_from(&mut self, a: &Self) { - copy_vec_znx_from(self, a); - } - - pub fn normalize(&mut self, log_base2k: usize, col: usize, carry: &mut [u8]) { - normalize(log_base2k, self, col, carry) - } - - pub fn switch_degree(&self, col: usize, a: &mut Self, col_a: usize) { - switch_degree(a, col_a, self, col) + /// Switches degree of from `a.n()` to `self.n()` into `self` + pub fn switch_degree>(&mut self, col: usize, a: &Data, col_a: usize) { + switch_degree(self, col_a, a, col) } // Prints the first `n` coefficients of each limb - pub fn print(&self, n: usize, col: usize) { - (0..self.size()).for_each(|j| println!("{}: {:?}", j, &self.at(col, j)[..n])); + // pub fn print(&self, n: usize, col: usize) { + // (0..self.size()).for_each(|j| println!("{}: {:?}", j, &self.at(col, j)[..n])); + // } +} + +impl>> VecZnx { + pub(crate) fn bytes_of(n: usize, cols: usize, size: usize) -> usize { + n * cols * size * size_of::() + } + + pub(crate) fn new(n: usize, cols: usize, size: usize) -> Self { + let data = alloc_aligned::(Self::bytes_of::(n, cols, size)); + Self { + data: data.into(), + n, + cols, + size, + } + } + + pub(crate) fn new_from_bytes(n: usize, cols: usize, size: usize, bytes: impl Into>) -> Self { + let data: Vec = bytes.into(); + assert!(data.len() == Self::bytes_of::(n, cols, size)); + Self { + data: data.into(), + n, + cols, + size, + } } } +//(Jay)TODO: Impl. truncate pow2 for Owned Vector + +/// Copies the coefficients of `a` on the receiver. +/// Copy is done with the minimum size matching both backing arrays. +/// Panics if the cols do not match. +pub fn copy_vec_znx_from(b: &mut VecZnx, a: &VecZnx) +where + DataMut: AsMut<[u8]> + AsRef<[u8]>, + Data: AsRef<[u8]>, +{ + assert_eq!(b.cols(), a.cols()); + let data_a: &[i64] = a.raw(); + let data_b: &mut [i64] = b.raw_mut(); + let size = min(data_b.len(), data_a.len()); + data_b[..size].copy_from_slice(&data_a[..size]) +} + +// if !self.borrowing() { +// self.inner +// .data +// .truncate(self.n() * self.cols() * (self.size() - k / log_base2k)); +// } + fn normalize_tmp_bytes(n: usize) -> usize { n * std::mem::size_of::() } -fn normalize(log_base2k: usize, a: &mut VecZnx, a_col: usize, tmp_bytes: &mut [u8]) { +fn normalize>(log_base2k: usize, a: &mut VecZnx, a_col: usize, tmp_bytes: &mut [u8]) { let n: usize = a.n(); debug_assert!( @@ -162,3 +192,62 @@ fn normalize(log_base2k: usize, a: &mut VecZnx, a_col: usize, tmp_bytes: &mut [u }); } } + +// impl ZnxAlloc for VecZnx { +// type Scalar = i64; + +// fn from_bytes_borrow(module: &Module, _rows: usize, cols: usize, size: usize, bytes: &mut [u8]) -> VecZnx { +// debug_assert_eq!(bytes.len(), Self::bytes_of(module, _rows, cols, size)); +// VecZnx { +// inner: ZnxBase::from_bytes_borrow(module.n(), VEC_ZNX_ROWS, cols, size, bytes), +// } +// } + +// fn bytes_of(module: &Module, _rows: usize, cols: usize, size: usize) -> usize { +// debug_assert_eq!( +// _rows, VEC_ZNX_ROWS, +// "rows != {} not supported for VecZnx", +// VEC_ZNX_ROWS +// ); +// module.n() * cols * size * size_of::() +// } +// } + +impl> fmt::Display for VecZnx { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + writeln!( + f, + "VecZnx(n={}, cols={}, size={})", + self.n, self.cols, self.size + )?; + + for col in 0..self.cols { + writeln!(f, "Column {}:", col)?; + for size in 0..self.size { + let coeffs = self.at(col, size); + write!(f, " Size {}: [", size)?; + + let max_show = 100; + let show_count = coeffs.len().min(max_show); + + for (i, &coeff) in coeffs.iter().take(show_count).enumerate() { + if i > 0 { + write!(f, ", ")?; + } + write!(f, "{}", coeff)?; + } + + if coeffs.len() > max_show { + write!(f, ", ... ({} more)", coeffs.len() - max_show)?; + } + + writeln!(f, "]")?; + } + } + Ok(()) + } +} + +pub type VecZnxOwned = VecZnx>; +pub type VecZnxMut<'a> = VecZnx<&'a mut [u8]>; +pub type VecZnxRef<'a> = VecZnx<&'a [u8]>; diff --git a/base2k/src/vec_znx_big.rs b/base2k/src/vec_znx_big.rs index 5ba7dde..682493a 100644 --- a/base2k/src/vec_znx_big.rs +++ b/base2k/src/vec_znx_big.rs @@ -1,74 +1,91 @@ use crate::ffi::vec_znx_big; -use crate::znx_base::{GetZnxBase, ZnxAlloc, ZnxBase, ZnxInfos, ZnxLayout, ZnxSliceSize, ZnxZero}; -use crate::{Backend, FFT64, Module, NTT120}; +use crate::znx_base::{GetZnxBase, ZnxAlloc, ZnxBase, ZnxInfos, ZnxView}; +use crate::{Backend, DataView, DataViewMut, FFT64, Module, alloc_aligned}; use std::marker::PhantomData; const VEC_ZNX_BIG_ROWS: usize = 1; -pub struct VecZnxBig { - pub inner: ZnxBase, - pub _marker: PhantomData, +/// VecZnxBig is Backend dependent, denoted with backend generic `B` +pub struct VecZnxBig { + data: D, + n: usize, + cols: usize, + size: usize, + _phantom: PhantomData, } -impl GetZnxBase for VecZnxBig { - fn znx(&self) -> &ZnxBase { - &self.inner +impl ZnxInfos for VecZnxBig { + fn cols(&self) -> usize { + self.cols } - fn znx_mut(&mut self) -> &mut ZnxBase { - &mut self.inner + fn rows(&self) -> usize { + 1 + } + + fn n(&self) -> usize { + self.n + } + + fn size(&self) -> usize { + self.size + } + + fn sl(&self) -> usize { + self.cols() * self.n() } } -impl ZnxInfos for VecZnxBig {} - -impl ZnxAlloc for VecZnxBig { - type Scalar = u8; - - fn from_bytes_borrow(module: &Module, _rows: usize, cols: usize, size: usize, bytes: &mut [u8]) -> Self { - debug_assert_eq!(bytes.len(), Self::bytes_of(module, _rows, cols, size)); - VecZnxBig { - inner: ZnxBase::from_bytes_borrow(module.n(), VEC_ZNX_BIG_ROWS, cols, size, bytes), - _marker: PhantomData, - } - } - - fn bytes_of(module: &Module, _rows: usize, cols: usize, size: usize) -> usize { - debug_assert_eq!( - _rows, VEC_ZNX_BIG_ROWS, - "rows != {} not supported for VecZnxBig", - VEC_ZNX_BIG_ROWS - ); - unsafe { vec_znx_big::bytes_of_vec_znx_big(module.ptr, size as u64) as usize * cols } +impl DataView for VecZnxBig { + type D = D; + fn data(&self) -> &Self::D { + &self.data } } -impl ZnxLayout for VecZnxBig { +impl DataViewMut for VecZnxBig { + fn data_mut(&self) -> &mut Self::D { + &mut self.data + } +} + +impl> ZnxView for VecZnxBig { type Scalar = i64; } -impl ZnxLayout for VecZnxBig { - type Scalar = i128; -} +impl>, B: Backend> VecZnxBig { + pub(crate) fn bytes_of(module: &Module, cols: usize, size: usize) -> usize { + unsafe { vec_znx_big::bytes_of_vec_znx_big(module.ptr, size as u64) as usize * cols } + } -impl ZnxZero for VecZnxBig {} + pub(crate) fn new(module: &Module, cols: usize, size: usize) -> Self { + let data = alloc_aligned::(Self::bytes_of(module, cols, size)); + Self { + data: data.into(), + n: module.n(), + cols, + size, + _phantom: PhantomData, + } + } -impl ZnxSliceSize for VecZnxBig { - fn sl(&self) -> usize { - self.n() * self.cols() + pub(crate) fn new_from_bytes(module: &Module, cols: usize, size: usize, bytes: impl Into>) -> Self { + let data: Vec = bytes.into(); + assert!(data.len() == Self::bytes_of(module, cols, size)); + Self { + data: data.into(), + n: module.n(), + cols, + size, + _phantom: PhantomData, + } } } -impl ZnxSliceSize for VecZnxBig { - fn sl(&self) -> usize { - self.n() * 4 * self.cols() - } -} +pub type VecZnxBigOwned = VecZnxBig, B>; -impl ZnxZero for VecZnxBig {} - -impl VecZnxBig { - pub fn print(&self, n: usize, col: usize) { - (0..self.size()).for_each(|i| println!("{}: {:?}", i, &self.at(col, i)[..n])); - } -} +// impl VecZnxBig { +// pub fn print(&self, n: usize, col: usize) { +// (0..self.size()).for_each(|i| println!("{}: {:?}", i, &self.at(col, i)[..n])); +// } +// } diff --git a/base2k/src/vec_znx_big_ops.rs b/base2k/src/vec_znx_big_ops.rs index 8be526e..5353c32 100644 --- a/base2k/src/vec_znx_big_ops.rs +++ b/base2k/src/vec_znx_big_ops.rs @@ -1,10 +1,10 @@ use crate::ffi::vec_znx; -use crate::znx_base::{ZnxAlloc, ZnxInfos, ZnxLayout, ZnxSliceSize}; -use crate::{Backend, FFT64, Module, VecZnx, VecZnxBig, VecZnxOps, assert_alignement}; +use crate::znx_base::{ZnxAlloc, ZnxInfos, ZnxView, ZnxViewMut}; +use crate::{Backend, DataView, FFT64, Module, VecZnx, VecZnxBig, VecZnxBigOwned, VecZnxOps, assert_alignement}; -pub trait VecZnxBigOps { +pub trait VecZnxBigAlloc { /// Allocates a vector Z[X]/(X^N+1) that stores not normalized values. - fn new_vec_znx_big(&self, cols: usize, size: usize) -> VecZnxBig; + fn new_vec_znx_big(&self, cols: usize, size: usize) -> VecZnxBigOwned; /// Returns a new [VecZnxBig] with the provided bytes array as backing array. /// @@ -18,98 +18,100 @@ pub trait VecZnxBigOps { /// /// # Panics /// If `bytes.len()` < [Module::bytes_of_vec_znx_big]. - fn new_vec_znx_big_from_bytes(&self, cols: usize, size: usize, bytes: Vec) -> VecZnxBig; + fn new_vec_znx_big_from_bytes(&self, cols: usize, size: usize, bytes: Vec) -> VecZnxBigOwned; - /// Returns a new [VecZnxBig] with the provided bytes array as backing array. - /// - /// Behavior: the backing array is only borrowed. - /// - /// # Arguments - /// - /// * `cols`: the number of polynomials.. - /// * `size`: the number of polynomials per column. - /// * `bytes`: a byte array of size at least [Module::bytes_of_vec_znx_big]. - /// - /// # Panics - /// If `bytes.len()` < [Module::bytes_of_vec_znx_big]. - fn new_vec_znx_big_from_bytes_borrow(&self, cols: usize, size: usize, tmp_bytes: &mut [u8]) -> VecZnxBig; + // /// Returns a new [VecZnxBig] with the provided bytes array as backing array. + // /// + // /// Behavior: the backing array is only borrowed. + // /// + // /// # Arguments + // /// + // /// * `cols`: the number of polynomials.. + // /// * `size`: the number of polynomials per column. + // /// * `bytes`: a byte array of size at least [Module::bytes_of_vec_znx_big]. + // /// + // /// # Panics + // /// If `bytes.len()` < [Module::bytes_of_vec_znx_big]. + // fn new_vec_znx_big_from_bytes_borrow(&self, cols: usize, size: usize, tmp_bytes: &mut [u8]) -> VecZnxBig; /// Returns the minimum number of bytes necessary to allocate /// a new [VecZnxBig] through [VecZnxBig::from_bytes]. fn bytes_of_vec_znx_big(&self, cols: usize, size: usize) -> usize; +} +pub trait VecZnxBigOps { /// Adds `a` to `b` and stores the result on `c`. fn vec_znx_big_add( &self, - res: &mut VecZnxBig, + res: &mut VecZnxBig, res_col: usize, - a: &VecZnxBig, + a: &VecZnxBig, a_col: usize, - b: &VecZnxBig, + b: &VecZnxBig, b_col: usize, ); /// Adds `a` to `b` and stores the result on `b`. - fn vec_znx_big_add_inplace(&self, res: &mut VecZnxBig, res_col: usize, a: &VecZnxBig, a_col: usize); + fn vec_znx_big_add_inplace(&self, res: &mut VecZnxBig, res_col: usize, a: &VecZnxBig, a_col: usize); /// Adds `a` to `b` and stores the result on `c`. fn vec_znx_big_add_small( &self, - res: &mut VecZnxBig, + res: &mut VecZnxBig, res_col: usize, - a: &VecZnxBig, + a: &VecZnxBig, a_col: usize, - b: &VecZnx, + b: &VecZnx, b_col: usize, ); /// Adds `a` to `b` and stores the result on `b`. - fn vec_znx_big_add_small_inplace(&self, res: &mut VecZnxBig, res_col: usize, a: &VecZnx, a_col: usize); + fn vec_znx_big_add_small_inplace(&self, res: &mut VecZnxBig, res_col: usize, a: &VecZnx, a_col: usize); /// Subtracts `a` to `b` and stores the result on `c`. fn vec_znx_big_sub( &self, - res: &mut VecZnxBig, + res: &mut VecZnxBig, res_col: usize, - a: &VecZnxBig, + a: &VecZnxBig, a_col: usize, - b: &VecZnxBig, + b: &VecZnxBig, b_col: usize, ); /// Subtracts `a` to `b` and stores the result on `b`. - fn vec_znx_big_sub_ab_inplace(&self, res: &mut VecZnxBig, res_col: usize, a: &VecZnxBig, a_col: usize); + fn vec_znx_big_sub_ab_inplace(&self, res: &mut VecZnxBig, res_col: usize, a: &VecZnxBig, a_col: usize); /// Subtracts `b` to `a` and stores the result on `b`. - fn vec_znx_big_sub_ba_inplace(&self, res: &mut VecZnxBig, res_col: usize, a: &VecZnxBig, a_col: usize); + fn vec_znx_big_sub_ba_inplace(&self, res: &mut VecZnxBig, res_col: usize, a: &VecZnxBig, a_col: usize); /// Subtracts `b` to `a` and stores the result on `c`. fn vec_znx_big_sub_small_a( &self, - res: &mut VecZnxBig, + res: &mut VecZnxBig, res_col: usize, - a: &VecZnx, + a: &VecZnx, a_col: usize, - b: &VecZnxBig, + b: &VecZnxBig, b_col: usize, ); /// Subtracts `a` to `b` and stores the result on `b`. - fn vec_znx_big_sub_small_a_inplace(&self, res: &mut VecZnxBig, res_col: usize, a: &VecZnx, a_col: usize); + fn vec_znx_big_sub_small_a_inplace(&self, res: &mut VecZnxBig, res_col: usize, a: &VecZnx, a_col: usize); /// Subtracts `b` to `a` and stores the result on `c`. fn vec_znx_big_sub_small_b( &self, - res: &mut VecZnxBig, + res: &mut VecZnxBig, res_col: usize, - a: &VecZnxBig, + a: &VecZnxBig, a_col: usize, - b: &VecZnx, + b: &VecZnx, b_col: usize, ); /// Subtracts `b` to `a` and stores the result on `b`. - fn vec_znx_big_sub_small_b_inplace(&self, res: &mut VecZnxBig, res_col: usize, a: &VecZnx, a_col: usize); + fn vec_znx_big_sub_small_b_inplace(&self, res: &mut VecZnxBig, res_col: usize, a: &VecZnx, a_col: usize); /// Returns the minimum number of bytes to apply [VecZnxBigOps::vec_znx_big_normalize]. fn vec_znx_big_normalize_tmp_bytes(&self) -> usize; @@ -123,44 +125,57 @@ pub trait VecZnxBigOps { fn vec_znx_big_normalize( &self, log_base2k: usize, - res: &mut VecZnx, + res: &mut VecZnx, res_col: usize, - a: &VecZnxBig, + a: &VecZnxBig, a_col: usize, tmp_bytes: &mut [u8], ); /// Applies the automorphism X^i -> X^ik on `a` and stores the result on `b`. - fn vec_znx_big_automorphism(&self, k: i64, res: &mut VecZnxBig, res_col: usize, a: &VecZnxBig, a_col: usize); + fn vec_znx_big_automorphism( + &self, + k: i64, + res: &mut VecZnxBig, + res_col: usize, + a: &VecZnxBig, + a_col: usize, + ); /// Applies the automorphism X^i -> X^ik on `a` and stores the result on `a`. - fn vec_znx_big_automorphism_inplace(&self, k: i64, a: &mut VecZnxBig, a_col: usize); + fn vec_znx_big_automorphism_inplace(&self, k: i64, a: &mut VecZnxBig, a_col: usize); } -impl VecZnxBigOps for Module { - fn new_vec_znx_big(&self, cols: usize, size: usize) -> VecZnxBig { - VecZnxBig::new(self, 1, cols, size) +impl VecZnxBigAlloc for Module { + fn new_vec_znx_big(&self, cols: usize, size: usize) -> VecZnxBigOwned { + VecZnxBig::new(self, cols, size) } - fn new_vec_znx_big_from_bytes(&self, cols: usize, size: usize, bytes: Vec) -> VecZnxBig { - VecZnxBig::from_bytes(self, 1, cols, size, bytes) + fn new_vec_znx_big_from_bytes(&self, cols: usize, size: usize, bytes: Vec) -> VecZnxBigOwned { + VecZnxBig::new_from_bytes(self, cols, size, bytes) } - fn new_vec_znx_big_from_bytes_borrow(&self, cols: usize, size: usize, tmp_bytes: &mut [u8]) -> VecZnxBig { - VecZnxBig::from_bytes_borrow(self, 1, cols, size, tmp_bytes) - } + // fn new_vec_znx_big_from_bytes_borrow(&self, cols: usize, size: usize, tmp_bytes: &mut [u8]) -> VecZnxBig { + // VecZnxBig::from_bytes_borrow(self, 1, cols, size, tmp_bytes) + // } fn bytes_of_vec_znx_big(&self, cols: usize, size: usize) -> usize { - VecZnxBig::bytes_of(self, 1, cols, size) + VecZnxBig::bytes_of(self, cols, size) } +} +impl VecZnxBigOps for Module +where + DataMut: AsMut<[u8]> + AsRef<[u8]>, + Data: AsRef<[u8]>, +{ fn vec_znx_big_add( &self, - res: &mut VecZnxBig, + res: &mut VecZnxBig, res_col: usize, - a: &VecZnxBig, + a: &VecZnxBig, a_col: usize, - b: &VecZnxBig, + b: &VecZnxBig, b_col: usize, ) { #[cfg(debug_assertions)] @@ -186,20 +201,25 @@ impl VecZnxBigOps for Module { } } - fn vec_znx_big_add_inplace(&self, res: &mut VecZnxBig, res_col: usize, a: &VecZnxBig, a_col: usize) { + fn vec_znx_big_add_inplace( + &self, + res: &mut VecZnxBig, + res_col: usize, + a: &VecZnxBig, + a_col: usize, + ) { unsafe { - let res_ptr: *mut VecZnxBig = res as *mut VecZnxBig; - Self::vec_znx_big_add(self, &mut *res_ptr, res_col, a, a_col, &*res_ptr, res_col); + Self::vec_znx_big_add(self, res, res_col, a, a_col, res, res_col); } } fn vec_znx_big_sub( &self, - res: &mut VecZnxBig, + res: &mut VecZnxBig, res_col: usize, - a: &VecZnxBig, + a: &VecZnxBig, a_col: usize, - b: &VecZnxBig, + b: &VecZnxBig, b_col: usize, ) { #[cfg(debug_assertions)] @@ -225,27 +245,38 @@ impl VecZnxBigOps for Module { } } - fn vec_znx_big_sub_ab_inplace(&self, res: &mut VecZnxBig, res_col: usize, a: &VecZnxBig, a_col: usize) { + //(Jay)TODO: check whether definitions sub_ab, sub_ba make sense to you + fn vec_znx_big_sub_ab_inplace( + &self, + res: &mut VecZnxBig, + res_col: usize, + a: &VecZnxBig, + a_col: usize, + ) { unsafe { - let res_ptr: *mut VecZnxBig = res as *mut VecZnxBig; - Self::vec_znx_big_sub(self, &mut *res_ptr, res_col, a, a_col, &*res_ptr, res_col); + Self::vec_znx_big_sub(self, res, res_col, a, a_col, res, res_col); } } - fn vec_znx_big_sub_ba_inplace(&self, res: &mut VecZnxBig, res_col: usize, a: &VecZnxBig, a_col: usize) { + fn vec_znx_big_sub_ba_inplace( + &self, + res: &mut VecZnxBig, + res_col: usize, + a: &VecZnxBig, + a_col: usize, + ) { unsafe { - let res_ptr: *mut VecZnxBig = res as *mut VecZnxBig; - Self::vec_znx_big_sub(self, &mut *res_ptr, res_col, &*res_ptr, res_col, a, a_col); + Self::vec_znx_big_sub(self, res, res_col, res, res_col, a, a_col); } } fn vec_znx_big_sub_small_b( &self, - res: &mut VecZnxBig, + res: &mut VecZnxBig, res_col: usize, - a: &VecZnxBig, + a: &VecZnxBig, a_col: usize, - b: &VecZnx, + b: &VecZnx, b_col: usize, ) { #[cfg(debug_assertions)] @@ -271,20 +302,25 @@ impl VecZnxBigOps for Module { } } - fn vec_znx_big_sub_small_b_inplace(&self, res: &mut VecZnxBig, res_col: usize, a: &VecZnx, a_col: usize) { + fn vec_znx_big_sub_small_b_inplace( + &self, + res: &mut VecZnxBig, + res_col: usize, + a: &VecZnx, + a_col: usize, + ) { unsafe { - let res_ptr: *mut VecZnxBig = res as *mut VecZnxBig; - Self::vec_znx_big_sub_small_b(self, &mut *res_ptr, res_col, &*res_ptr, res_col, a, a_col); + Self::vec_znx_big_sub_small_b(self, res, res_col, res, res_col, a, a_col); } } fn vec_znx_big_sub_small_a( &self, - res: &mut VecZnxBig, + res: &mut VecZnxBig, res_col: usize, - a: &VecZnx, + a: &VecZnx, a_col: usize, - b: &VecZnxBig, + b: &VecZnxBig, b_col: usize, ) { #[cfg(debug_assertions)] @@ -310,20 +346,25 @@ impl VecZnxBigOps for Module { } } - fn vec_znx_big_sub_small_a_inplace(&self, res: &mut VecZnxBig, res_col: usize, a: &VecZnx, a_col: usize) { + fn vec_znx_big_sub_small_a_inplace( + &self, + res: &mut VecZnxBig, + res_col: usize, + a: &VecZnx, + a_col: usize, + ) { unsafe { - let res_ptr: *mut VecZnxBig = res as *mut VecZnxBig; - Self::vec_znx_big_sub_small_a(self, &mut *res_ptr, res_col, a, a_col, &*res_ptr, res_col); + Self::vec_znx_big_sub_small_a(self, res, res_col, a, a_col, res, res_col); } } fn vec_znx_big_add_small( &self, - res: &mut VecZnxBig, + res: &mut VecZnxBig, res_col: usize, - a: &VecZnxBig, + a: &VecZnxBig, a_col: usize, - b: &VecZnx, + b: &VecZnx, b_col: usize, ) { #[cfg(debug_assertions)] @@ -349,11 +390,8 @@ impl VecZnxBigOps for Module { } } - fn vec_znx_big_add_small_inplace(&self, res: &mut VecZnxBig, res_col: usize, a: &VecZnx, a_col: usize) { - unsafe { - let res_ptr: *mut VecZnxBig = res as *mut VecZnxBig; - Self::vec_znx_big_add_small(self, &mut *res_ptr, res_col, &*res_ptr, res_col, a, a_col); - } + fn vec_znx_big_add_small_inplace(&self, res: &mut VecZnxBig, res_col: usize, a: &VecZnx, a_col: usize) { + Self::vec_znx_big_add_small(self, res, res_col, res, res_col, a, a_col); } fn vec_znx_big_normalize_tmp_bytes(&self) -> usize { @@ -363,9 +401,9 @@ impl VecZnxBigOps for Module { fn vec_znx_big_normalize( &self, log_base2k: usize, - res: &mut VecZnx, + res: &mut VecZnx, res_col: usize, - a: &VecZnxBig, + a: &VecZnxBig, a_col: usize, tmp_bytes: &mut [u8], ) { @@ -391,7 +429,14 @@ impl VecZnxBigOps for Module { } } - fn vec_znx_big_automorphism(&self, k: i64, res: &mut VecZnxBig, res_col: usize, a: &VecZnxBig, a_col: usize) { + fn vec_znx_big_automorphism( + &self, + k: i64, + res: &mut VecZnxBig, + res_col: usize, + a: &VecZnxBig, + a_col: usize, + ) { #[cfg(debug_assertions)] { assert_eq!(a.n(), self.n()); @@ -411,10 +456,9 @@ impl VecZnxBigOps for Module { } } - fn vec_znx_big_automorphism_inplace(&self, k: i64, a: &mut VecZnxBig, a_col: usize) { + fn vec_znx_big_automorphism_inplace(&self, k: i64, a: &mut VecZnxBig, a_col: usize) { unsafe { - let a_ptr: *mut VecZnxBig = a as *mut VecZnxBig; - Self::vec_znx_big_automorphism(self, k, &mut *a_ptr, a_col, &*a_ptr, a_col); + Self::vec_znx_big_automorphism(self, k, a, a_col, a, a_col); } } } diff --git a/base2k/src/vec_znx_dft.rs b/base2k/src/vec_znx_dft.rs index b187645..c192486 100644 --- a/base2k/src/vec_znx_dft.rs +++ b/base2k/src/vec_znx_dft.rs @@ -1,85 +1,135 @@ -use crate::ffi::vec_znx_dft; -use crate::znx_base::{GetZnxBase, ZnxAlloc, ZnxBase, ZnxInfos, ZnxLayout, ZnxSliceSize, ZnxZero}; -use crate::{Backend, FFT64, Module, VecZnxBig}; use std::marker::PhantomData; +use crate::ffi::vec_znx_dft; +use crate::znx_base::{ZnxAlloc, ZnxInfos}; +use crate::{Backend, DataView, DataViewMut, FFT64, Module, ZnxView, alloc_aligned}; + const VEC_ZNX_DFT_ROWS: usize = 1; -pub struct VecZnxDft { - inner: ZnxBase, - pub _marker: PhantomData, +pub struct VecZnxDft { + data: D, + n: usize, + cols: usize, + size: usize, + _phantom: PhantomData, } -impl GetZnxBase for VecZnxDft { - fn znx(&self) -> &ZnxBase { - &self.inner +impl ZnxInfos for VecZnxDft { + fn cols(&self) -> usize { + self.cols } - fn znx_mut(&mut self) -> &mut ZnxBase { - &mut self.inner + fn rows(&self) -> usize { + 1 + } + + fn n(&self) -> usize { + self.n + } + + fn size(&self) -> usize { + self.size + } + + fn sl(&self) -> usize { + self.cols() * self.n() } } -impl ZnxInfos for VecZnxDft {} - -impl ZnxAlloc for VecZnxDft { - type Scalar = u8; - - fn from_bytes_borrow(module: &Module, _rows: usize, cols: usize, size: usize, bytes: &mut [u8]) -> Self { - debug_assert_eq!(bytes.len(), Self::bytes_of(module, _rows, cols, size)); - Self { - inner: ZnxBase::from_bytes_borrow(module.n(), VEC_ZNX_DFT_ROWS, cols, size, bytes), - _marker: PhantomData, - } - } - - fn bytes_of(module: &Module, _rows: usize, cols: usize, size: usize) -> usize { - debug_assert_eq!( - _rows, VEC_ZNX_DFT_ROWS, - "rows != {} not supported for VecZnxDft", - VEC_ZNX_DFT_ROWS - ); - unsafe { vec_znx_dft::bytes_of_vec_znx_dft(module.ptr, size as u64) as usize * cols } +impl DataView for VecZnxDft { + type D = D; + fn data(&self) -> &Self::D { + &self.data } } -impl ZnxLayout for VecZnxDft { +impl DataViewMut for VecZnxDft { + fn data_mut(&self) -> &mut Self::D { + &mut self.data + } +} + +impl> ZnxView for VecZnxDft { type Scalar = f64; } -impl ZnxZero for VecZnxDft {} - -impl ZnxSliceSize for VecZnxDft { - fn sl(&self) -> usize { - self.n() +impl>, B: Backend> VecZnxDft { + pub(crate) fn bytes_of(module: &Module, cols: usize, size: usize) -> usize { + unsafe { vec_znx_dft::bytes_of_vec_znx_dft(module.ptr, size as u64) as usize * cols } } -} -impl VecZnxDft { - pub fn print(&self, n: usize, col: usize) { - (0..self.size()).for_each(|i| println!("{}: {:?}", i, &self.at(col, i)[..n])); + pub(crate) fn new(module: &Module, cols: usize, size: usize) -> Self { + let data = alloc_aligned::(Self::bytes_of(module, cols, size)); + Self { + data: data.into(), + n: module.n(), + cols, + size, + _phantom: PhantomData, + } } -} -impl VecZnxDft { - /// Cast a [VecZnxDft] into a [VecZnxBig]. - /// The returned [VecZnxBig] shares the backing array - /// with the original [VecZnxDft]. - pub fn alias_as_vec_znx_big(&mut self) -> VecZnxBig { - assert!( - self.data().len() == 0, - "cannot alias VecZnxDft into VecZnxBig if it owns the data" - ); - VecZnxBig:: { - inner: ZnxBase { - data: Vec::new(), - ptr: self.ptr(), - n: self.n(), - rows: self.rows(), - cols: self.cols(), - size: self.size(), - }, - _marker: PhantomData, + pub(crate) fn new_from_bytes(module: &Module, cols: usize, size: usize, bytes: impl Into>) -> Self { + let data: Vec = bytes.into(); + assert!(data.len() == Self::bytes_of(module, cols, size)); + Self { + data: data.into(), + n: module.n(), + cols, + size, + _phantom: PhantomData, } } } + +pub type VecZnxDftOwned = VecZnxDft, B>; + +// impl ZnxAlloc for VecZnxDft { +// type Scalar = u8; + +// fn from_bytes_borrow(module: &Module, _rows: usize, cols: usize, size: usize, bytes: &mut [u8]) -> Self { +// debug_assert_eq!(bytes.len(), Self::bytes_of(module, _rows, cols, size)); +// Self { +// inner: ZnxBase::from_bytes_borrow(module.n(), VEC_ZNX_DFT_ROWS, cols, size, bytes), +// _marker: PhantomData, +// } +// } + +// fn bytes_of(module: &Module, _rows: usize, cols: usize, size: usize) -> usize { +// debug_assert_eq!( +// _rows, VEC_ZNX_DFT_ROWS, +// "rows != {} not supported for VecZnxDft", +// VEC_ZNX_DFT_ROWS +// ); +// unsafe { vec_znx_dft::bytes_of_vec_znx_dft(module.ptr, size as u64) as usize * cols } +// } +// } + +// impl VecZnxDft { +// pub fn print(&self, n: usize, col: usize) { +// (0..self.size()).for_each(|i| println!("{}: {:?}", i, &self.at(col, i)[..n])); +// } +// } + +// impl VecZnxDft { +// /// Cast a [VecZnxDft] into a [VecZnxBig]. +// /// The returned [VecZnxBig] shares the backing array +// /// with the original [VecZnxDft]. +// pub fn alias_as_vec_znx_big(&mut self) -> VecZnxBig { +// assert!( +// self.data().len() == 0, +// "cannot alias VecZnxDft into VecZnxBig if it owns the data" +// ); +// VecZnxBig:: { +// inner: ZnxBase { +// data: Vec::new(), +// ptr: self.ptr(), +// n: self.n(), +// rows: self.rows(), +// cols: self.cols(), +// size: self.size(), +// }, +// _marker: PhantomData, +// } +// } +// } diff --git a/base2k/src/vec_znx_dft_ops.rs b/base2k/src/vec_znx_dft_ops.rs index 679abce..cf2090b 100644 --- a/base2k/src/vec_znx_dft_ops.rs +++ b/base2k/src/vec_znx_dft_ops.rs @@ -1,15 +1,14 @@ +use crate::VecZnxDftOwned; use crate::ffi::vec_znx_big; use crate::ffi::vec_znx_dft; use crate::znx_base::ZnxAlloc; use crate::znx_base::ZnxInfos; -use crate::znx_base::ZnxLayout; -use crate::znx_base::ZnxSliceSize; -use crate::{Backend, FFT64, Module, VecZnx, VecZnxBig, VecZnxDft, ZnxZero, assert_alignement}; +use crate::{FFT64, Module, VecZnx, VecZnxBig, VecZnxDft, ZnxView, ZnxViewMut, ZnxZero, assert_alignement}; use std::cmp::min; -pub trait VecZnxDftOps { +pub trait VecZnxDftAlloc { /// Allocates a vector Z[X]/(X^N+1) that stores normalized in the DFT space. - fn new_vec_znx_dft(&self, cols: usize, size: usize) -> VecZnxDft; + fn new_vec_znx_dft(&self, cols: usize, size: usize) -> VecZnxDftOwned; /// Returns a new [VecZnxDft] with the provided bytes array as backing array. /// @@ -22,20 +21,20 @@ pub trait VecZnxDftOps { /// /// # Panics /// If `bytes.len()` < [Module::bytes_of_vec_znx_dft]. - fn new_vec_znx_dft_from_bytes(&self, cols: usize, size: usize, bytes: Vec) -> VecZnxDft; + fn new_vec_znx_dft_from_bytes(&self, cols: usize, size: usize, bytes: Vec) -> VecZnxDftOwned; - /// Returns a new [VecZnxDft] with the provided bytes array as backing array. - /// - /// Behavior: the backing array is only borrowed. - /// - /// # Arguments - /// - /// * `cols`: the number of cols of the [VecZnxDft]. - /// * `bytes`: a byte array of size at least [Module::bytes_of_vec_znx_dft]. - /// - /// # Panics - /// If `bytes.len()` < [Module::bytes_of_vec_znx_dft]. - fn new_vec_znx_dft_from_bytes_borrow(&self, cols: usize, size: usize, bytes: &mut [u8]) -> VecZnxDft; + // /// Returns a new [VecZnxDft] with the provided bytes array as backing array. + // /// + // /// Behavior: the backing array is only borrowed. + // /// + // /// # Arguments + // /// + // /// * `cols`: the number of cols of the [VecZnxDft]. + // /// * `bytes`: a byte array of size at least [Module::bytes_of_vec_znx_dft]. + // /// + // /// # Panics + // /// If `bytes.len()` < [Module::bytes_of_vec_znx_dft]. + // fn new_vec_znx_dft_from_bytes_borrow(&self, cols: usize, size: usize, bytes: &mut [u8]) -> VecZnxDft; /// Returns a new [VecZnxDft] with the provided bytes array as backing array. /// @@ -47,37 +46,58 @@ pub trait VecZnxDftOps { /// # Panics /// If `bytes.len()` < [Module::bytes_of_vec_znx_dft]. fn bytes_of_vec_znx_dft(&self, cols: usize, size: usize) -> usize; +} +pub trait VecZnxDftOps { /// Returns the minimum number of bytes necessary to allocate /// a new [VecZnxDft] through [VecZnxDft::from_bytes]. fn vec_znx_idft_tmp_bytes(&self) -> usize; /// b <- IDFT(a), uses a as scratch space. - fn vec_znx_idft_tmp_a(&self, res: &mut VecZnxBig, res_col: usize, a: &mut VecZnxDft, a_cols: usize); + fn vec_znx_idft_tmp_a(&self, res: &mut VecZnxBig, res_col: usize, a: &mut VecZnxDft, a_cols: usize); - fn vec_znx_idft(&self, res: &mut VecZnxBig, res_col: usize, a: &VecZnxDft, a_col: usize, tmp_bytes: &mut [u8]); + fn vec_znx_idft( + &self, + res: &mut VecZnxBig, + res_col: usize, + a: &VecZnxDft, + a_col: usize, + tmp_bytes: &mut [u8], + ); - fn vec_znx_dft(&self, res: &mut VecZnxDft, res_col: usize, a: &VecZnx, a_col: usize); + fn vec_znx_dft(&self, res: &mut VecZnxDft, res_col: usize, a: &VecZnx, a_col: usize); } -impl VecZnxDftOps for Module { - fn new_vec_znx_dft(&self, cols: usize, size: usize) -> VecZnxDft { - VecZnxDft::::new(&self, 1, cols, size) +impl VecZnxDftAlloc for Module { + fn new_vec_znx_dft(&self, cols: usize, size: usize) -> VecZnxDftOwned { + VecZnxDftOwned::new(&self, cols, size) } - fn new_vec_znx_dft_from_bytes(&self, cols: usize, size: usize, bytes: Vec) -> VecZnxDft { - VecZnxDft::from_bytes(self, 1, cols, size, bytes) + fn new_vec_znx_dft_from_bytes(&self, cols: usize, size: usize, bytes: Vec) -> VecZnxDftOwned { + VecZnxDftOwned::new_from_bytes(self, cols, size, bytes) } - fn new_vec_znx_dft_from_bytes_borrow(&self, cols: usize, size: usize, bytes: &mut [u8]) -> VecZnxDft { - VecZnxDft::from_bytes_borrow(self, 1, cols, size, bytes) - } + // fn new_vec_znx_dft_from_bytes_borrow(&self, cols: usize, size: usize, bytes: &mut [u8]) -> VecZnxDft { + // VecZnxDft::from_bytes_borrow(self, 1, cols, size, bytes) + // } fn bytes_of_vec_znx_dft(&self, cols: usize, size: usize) -> usize { - VecZnxDft::bytes_of(&self, 1, cols, size) + VecZnxDft::bytes_of(&self, cols, size) } +} - fn vec_znx_idft_tmp_a(&self, res: &mut VecZnxBig, res_col: usize, a: &mut VecZnxDft, a_col: usize) { +impl VecZnxDftOps for Module +where + DataMut: AsMut<[u8]> + AsRef<[u8]>, + Data: AsRef<[u8]>, +{ + fn vec_znx_idft_tmp_a( + &self, + res: &mut VecZnxBig, + res_col: usize, + a: &mut VecZnxDft, + a_col: usize, + ) { let min_size: usize = min(res.size(), a.size()); unsafe { @@ -86,7 +106,7 @@ impl VecZnxDftOps for Module { self.ptr, res.at_mut_ptr(res_col, j) as *mut vec_znx_big::vec_znx_big_t, 1 as u64, - a.at_ptr(a_col, j) as *mut vec_znx_dft::vec_znx_dft_t, + a.at_mut_ptr(a_col, j) as *mut vec_znx_dft::vec_znx_dft_t, 1 as u64, ) }); @@ -104,7 +124,7 @@ impl VecZnxDftOps for Module { /// /// # Panics /// If b.cols < a_cols - fn vec_znx_dft(&self, res: &mut VecZnxDft, res_col: usize, a: &VecZnx, a_col: usize) { + fn vec_znx_dft(&self, res: &mut VecZnxDft, res_col: usize, a: &VecZnx, a_col: usize) { let min_size: usize = min(res.size(), a.size()); unsafe { @@ -125,7 +145,14 @@ impl VecZnxDftOps for Module { } // b <- IDFT(a), scratch space size obtained with [vec_znx_idft_tmp_bytes]. - fn vec_znx_idft(&self, res: &mut VecZnxBig, res_col: usize, a: &VecZnxDft, a_col: usize, tmp_bytes: &mut [u8]) { + fn vec_znx_idft( + &self, + res: &mut VecZnxBig, + res_col: usize, + a: &VecZnxDft, + a_col: usize, + tmp_bytes: &mut [u8], + ) { #[cfg(debug_assertions)] { assert!( diff --git a/base2k/src/vec_znx_ops.rs b/base2k/src/vec_znx_ops.rs index 6365ad3..339bc12 100644 --- a/base2k/src/vec_znx_ops.rs +++ b/base2k/src/vec_znx_ops.rs @@ -1,14 +1,15 @@ use crate::ffi::vec_znx; -use crate::znx_base::{ZnxAlloc, ZnxInfos, ZnxLayout, ZnxSliceSize, switch_degree}; -use crate::{Backend, Module, VEC_ZNX_ROWS, VecZnx, assert_alignement}; -pub trait VecZnxOps { +use crate::znx_base::{ZnxInfos, switch_degree}; +use crate::{Backend, Module, VecZnx, VecZnxOwned, ZnxView, ZnxViewMut, assert_alignement}; + +pub trait VecZnxAlloc { /// Allocates a new [VecZnx]. /// /// # Arguments /// /// * `cols`: the number of polynomials. /// * `size`: the number small polynomials per column. - fn new_vec_znx(&self, cols: usize, size: usize) -> VecZnx; + fn new_vec_znx(&self, cols: usize, size: usize) -> VecZnxOwned; /// Instantiates a new [VecZnx] from a slice of bytes. /// The returned [VecZnx] takes ownership of the slice of bytes. @@ -20,25 +21,28 @@ pub trait VecZnxOps { /// /// # Panic /// Requires the slice of bytes to be equal to [VecZnxOps::bytes_of_vec_znx]. - fn new_vec_znx_from_bytes(&self, cols: usize, size: usize, bytes: Vec) -> VecZnx; + fn new_vec_znx_from_bytes(&self, cols: usize, size: usize, bytes: Vec) -> VecZnxOwned; - /// Instantiates a new [VecZnx] from a slice of bytes. - /// The returned [VecZnx] does take ownership of the slice of bytes. - /// - /// # Arguments - /// - /// * `cols`: the number of polynomials. - /// * `size`: the number small polynomials per column. - /// - /// # Panic - /// Requires the slice of bytes to be equal to [VecZnxOps::bytes_of_vec_znx]. - fn new_vec_znx_from_bytes_borrow(&self, cols: usize, size: usize, tmp_bytes: &mut [u8]) -> VecZnx; + // /// Instantiates a new [VecZnx] from a slice of bytes. + // /// The returned [VecZnx] does take ownership of the slice of bytes. + // /// + // /// # Arguments + // /// + // /// * `cols`: the number of polynomials. + // /// * `size`: the number small polynomials per column. + // /// + // /// # Panic + // /// Requires the slice of bytes to be equal to [VecZnxOps::bytes_of_vec_znx]. + // fn new_vec_znx_from_bytes_borrow(&self, cols: usize, size: usize, tmp_bytes: &mut [u8]) -> VecZnx; + // (Jay)TODO /// Returns the number of bytes necessary to allocate /// a new [VecZnx] through [VecZnxOps::new_vec_znx_from_bytes] /// or [VecZnxOps::new_vec_znx_from_bytes_borrow]. fn bytes_of_vec_znx(&self, cols: usize, size: usize) -> usize; +} +pub trait VecZnxOps { /// Returns the minimum number of bytes necessary for normalization. fn vec_znx_normalize_tmp_bytes(&self) -> usize; @@ -46,48 +50,64 @@ pub trait VecZnxOps { fn vec_znx_normalize( &self, log_base2k: usize, - res: &mut VecZnx, + res: &mut VecZnx, res_col: usize, - a: &VecZnx, + a: &VecZnx, a_col: usize, tmp_bytes: &mut [u8], ); /// Normalizes the selected column of `a`. - fn vec_znx_normalize_inplace(&self, log_base2k: usize, a: &mut VecZnx, a_col: usize, tmp_bytes: &mut [u8]); + fn vec_znx_normalize_inplace(&self, log_base2k: usize, a: &mut VecZnx, a_col: usize, tmp_bytes: &mut [u8]); - /// Adds the selected column of `a` to the selected column of `b` and write the result on the selected column of `c`. - fn vec_znx_add(&self, res: &mut VecZnx, res_col: usize, a: &VecZnx, a_col: usize, b: &VecZnx, b_col: usize); + /// Adds the selected column of `a` to the selected column of `b` and writes the result on the selected column of `res`. + fn vec_znx_add( + &self, + res: &mut VecZnx, + res_col: usize, + a: &VecZnx, + a_col: usize, + b: &VecZnx, + b_col: usize, + ); - /// Adds the selected column of `a` to the selected column of `b` and write the result on the selected column of `res`. - fn vec_znx_add_inplace(&self, res: &mut VecZnx, res_col: usize, a: &VecZnx, a_col: usize); + /// Adds the selected column of `a` to the selected column of `b` and writes the result on the selected column of `res`. + fn vec_znx_add_inplace(&self, res: &mut VecZnx, res_col: usize, a: &VecZnx, a_col: usize); - /// Subtracts the selected column of `b` to the selected column of `a` and write the result on the selected column of `res`. - fn vec_znx_sub(&self, res: &mut VecZnx, res_col: usize, a: &VecZnx, a_col: usize, b: &VecZnx, b_col: usize); + /// Subtracts the selected column of `b` from the selected column of `a` and writes the result on the selected column of `res`. + fn vec_znx_sub( + &self, + res: &mut VecZnx, + res_col: usize, + a: &VecZnx, + a_col: usize, + b: &VecZnx, + b_col: usize, + ); - /// Subtracts the selected column of `a` to the selected column of `res`. - fn vec_znx_sub_ab_inplace(&self, res: &mut VecZnx, res_col: usize, a: &VecZnx, a_col: usize); + /// Subtracts the selected column of `a` from the selected column of `res` inplace. + fn vec_znx_sub_ab_inplace(&self, res: &mut VecZnx, res_col: usize, a: &VecZnx, a_col: usize); - /// Subtracts the selected column of `a` to the selected column of `res` and negates the selected column of `res`. - fn vec_znx_sub_ba_inplace(&self, res: &mut VecZnx, res_col: usize, a: &VecZnx, a_col: usize); + // /// Subtracts the selected column of `a` from the selected column of `res` and negates the selected column of `res`. + // fn vec_znx_sub_ba_inplace(&self, res: &mut VecZnx, res_col: usize, a: &VecZnx, a_col: usize); - // Negates the selected column of `a` and stores the result on the selected column of `res`. - fn vec_znx_negate(&self, res: &mut VecZnx, res_col: usize, a: &VecZnx, a_col: usize); + // Negates the selected column of `a` and stores the result in `res_col` of `res`. + fn vec_znx_negate(&self, res: &mut VecZnx, res_col: usize, a: &VecZnx, a_col: usize); /// Negates the selected column of `a`. - fn vec_znx_negate_inplace(&self, a: &mut VecZnx, a_col: usize); + fn vec_znx_negate_inplace(&self, a: &mut VecZnx, a_col: usize); - /// Multiplies the selected column of `a` by X^k and stores the result on the selected column of `res`. - fn vec_znx_rotate(&self, k: i64, res: &mut VecZnx, res_col: usize, a: &VecZnx, a_col: usize); + /// Multiplies the selected column of `a` by X^k and stores the result in `res_col` of `res`. + fn vec_znx_rotate(&self, k: i64, res: &mut VecZnx, res_col: usize, a: &VecZnx, a_col: usize); /// Multiplies the selected column of `a` by X^k. - fn vec_znx_rotate_inplace(&self, k: i64, a: &mut VecZnx, a_col: usize); + fn vec_znx_rotate_inplace(&self, k: i64, a: &mut VecZnx, a_col: usize); - /// Applies the automorphism X^i -> X^ik on the selected column of `a` and stores the result on the selected column of `res`. - fn vec_znx_automorphism(&self, k: i64, res: &mut VecZnx, res_col: usize, a: &VecZnx, a_col: usize); + /// Applies the automorphism X^i -> X^ik on the selected column of `a` and stores the result in `res_col` column of `res`. + fn vec_znx_automorphism(&self, k: i64, res: &mut VecZnx, res_col: usize, a: &VecZnx, a_col: usize); /// Applies the automorphism X^i -> X^ik on the selected column of `a`. - fn vec_znx_automorphism_inplace(&self, k: i64, a: &mut VecZnx, a_col: usize); + fn vec_znx_automorphism_inplace(&self, k: i64, a: &mut VecZnx, a_col: usize); /// Splits the selected columns of `b` into subrings and copies them them into the selected column of `res`. /// @@ -95,7 +115,14 @@ pub trait VecZnxOps { /// /// This method requires that all [VecZnx] of b have the same ring degree /// and that b.n() * b.len() <= a.n() - fn vec_znx_split(&self, res: &mut Vec, res_col: usize, a: &VecZnx, a_col: usize, buf: &mut VecZnx); + fn vec_znx_split( + &self, + res: &mut Vec>, + res_col: usize, + a: &VecZnx, + a_col: usize, + buf: &mut VecZnx, + ); /// Merges the subrings of the selected column of `a` into the selected column of `res`. /// @@ -103,26 +130,29 @@ pub trait VecZnxOps { /// /// This method requires that all [VecZnx] of a have the same ring degree /// and that a.n() * a.len() <= b.n() - fn vec_znx_merge(&self, res: &mut VecZnx, res_col: usize, a: &Vec, a_col: usize); + fn vec_znx_merge(&self, res: &mut VecZnx, res_col: usize, a: &Vec>, a_col: usize); } -impl VecZnxOps for Module { - fn new_vec_znx(&self, cols: usize, size: usize) -> VecZnx { - VecZnx::new(self, VEC_ZNX_ROWS, cols, size) +impl VecZnxAlloc for Module { + //(Jay)TODO: One must define the Scalar generic param here. + fn new_vec_znx(&self, cols: usize, size: usize) -> VecZnxOwned { + VecZnxOwned::new(self.n(), cols, size) } fn bytes_of_vec_znx(&self, cols: usize, size: usize) -> usize { - VecZnx::bytes_of(self, VEC_ZNX_ROWS, cols, size) + VecZnxOwned::bytes_of(self.n(), cols, size) } - fn new_vec_znx_from_bytes(&self, cols: usize, size: usize, bytes: Vec) -> VecZnx { - VecZnx::from_bytes(self, VEC_ZNX_ROWS, cols, size, bytes) - } - - fn new_vec_znx_from_bytes_borrow(&self, cols: usize, size: usize, tmp_bytes: &mut [u8]) -> VecZnx { - VecZnx::from_bytes_borrow(self, VEC_ZNX_ROWS, cols, size, tmp_bytes) + fn new_vec_znx_from_bytes(&self, cols: usize, size: usize, bytes: Vec) -> VecZnxOwned { + VecZnxOwned::new_from_bytes(self.n(), cols, size, bytes) } +} +impl VecZnxOps for Module +where + Data: AsRef<[u8]>, + DataMut: AsRef<[u8]> + AsMut<[u8]>, +{ fn vec_znx_normalize_tmp_bytes(&self) -> usize { unsafe { vec_znx::vec_znx_normalize_base2k_tmp_bytes(self.ptr) as usize } } @@ -130,9 +160,9 @@ impl VecZnxOps for Module { fn vec_znx_normalize( &self, log_base2k: usize, - res: &mut VecZnx, + res: &mut VecZnx, res_col: usize, - a: &VecZnx, + a: &VecZnx, a_col: usize, tmp_bytes: &mut [u8], ) { @@ -158,7 +188,7 @@ impl VecZnxOps for Module { } } - fn vec_znx_normalize_inplace(&self, log_base2k: usize, a: &mut VecZnx, a_col: usize, tmp_bytes: &mut [u8]) { + fn vec_znx_normalize_inplace(&self, log_base2k: usize, a: &mut VecZnx, a_col: usize, tmp_bytes: &mut [u8]) { unsafe { let a_ptr: *mut VecZnx = a as *mut VecZnx; Self::vec_znx_normalize( @@ -173,7 +203,15 @@ impl VecZnxOps for Module { } } - fn vec_znx_add(&self, res: &mut VecZnx, res_col: usize, a: &VecZnx, a_col: usize, b: &VecZnx, b_col: usize) { + fn vec_znx_add( + &self, + res: &mut VecZnx, + res_col: usize, + a: &VecZnx, + a_col: usize, + b: &VecZnx, + b_col: usize, + ) { #[cfg(debug_assertions)] { assert_eq!(a.n(), self.n()); @@ -197,14 +235,21 @@ impl VecZnxOps for Module { } } - fn vec_znx_add_inplace(&self, res: &mut VecZnx, res_col: usize, a: &VecZnx, a_col: usize) { + fn vec_znx_add_inplace(&self, res: &mut VecZnx, res_col: usize, a: &VecZnx, a_col: usize) { unsafe { - let res_ptr: *mut VecZnx = res as *mut VecZnx; - Self::vec_znx_add(self, &mut *res_ptr, res_col, a, a_col, &*res_ptr, res_col); + Self::vec_znx_add(&self, res, res_col, a, a_col, res, res_col); } } - fn vec_znx_sub(&self, res: &mut VecZnx, res_col: usize, a: &VecZnx, a_col: usize, b: &VecZnx, b_col: usize) { + fn vec_znx_sub( + &self, + res: &mut VecZnx, + res_col: usize, + a: &VecZnx, + a_col: usize, + b: &VecZnx, + b_col: usize, + ) { #[cfg(debug_assertions)] { assert_eq!(a.n(), self.n()); @@ -228,21 +273,21 @@ impl VecZnxOps for Module { } } - fn vec_znx_sub_ab_inplace(&self, res: &mut VecZnx, res_col: usize, a: &VecZnx, a_col: usize) { + fn vec_znx_sub_ab_inplace(&self, res: &mut VecZnx, res_col: usize, a: &VecZnx, a_col: usize) { unsafe { let res_ptr: *mut VecZnx = res as *mut VecZnx; - Self::vec_znx_sub(self, &mut *res_ptr, res_col, a, a_col, &*res_ptr, res_col); + Self::vec_znx_sub(self, res, res_col, a, a_col, res, res_col); } } - fn vec_znx_sub_ba_inplace(&self, res: &mut VecZnx, res_col: usize, a: &VecZnx, a_col: usize) { - unsafe { - let res_ptr: *mut VecZnx = res as *mut VecZnx; - Self::vec_znx_sub(self, &mut *res_ptr, res_col, &*res_ptr, res_col, a, a_col); - } - } + // fn vec_znx_sub_ba_inplace(&self, res: &mut VecZnx, res_col: usize, a: &VecZnx, a_col: usize) { + // unsafe { + // let res_ptr: *mut VecZnx = res as *mut VecZnx; + // Self::vec_znx_sub(self, &mut *res_ptr, res_col, &*res_ptr, res_col, a, a_col); + // } + // } - fn vec_znx_negate(&self, res: &mut VecZnx, res_col: usize, a: &VecZnx, a_col: usize) { + fn vec_znx_negate(&self, res: &mut VecZnx, res_col: usize, a: &VecZnx, a_col: usize) { #[cfg(debug_assertions)] { assert_eq!(a.n(), self.n()); @@ -261,14 +306,13 @@ impl VecZnxOps for Module { } } - fn vec_znx_negate_inplace(&self, a: &mut VecZnx, a_col: usize) { + fn vec_znx_negate_inplace(&self, a: &mut VecZnx, a_col: usize) { unsafe { - let a_ptr: *mut VecZnx = a as *mut VecZnx; - Self::vec_znx_negate(self, &mut *a_ptr, a_col, &*a_ptr, a_col); + Self::vec_znx_negate(self, a, a_col, a, a_col); } } - fn vec_znx_rotate(&self, k: i64, res: &mut VecZnx, res_col: usize, a: &VecZnx, a_col: usize) { + fn vec_znx_rotate(&self, k: i64, res: &mut VecZnx, res_col: usize, a: &VecZnx, a_col: usize) { #[cfg(debug_assertions)] { assert_eq!(a.n(), self.n()); @@ -288,14 +332,13 @@ impl VecZnxOps for Module { } } - fn vec_znx_rotate_inplace(&self, k: i64, a: &mut VecZnx, a_col: usize) { + fn vec_znx_rotate_inplace(&self, k: i64, a: &mut VecZnx, a_col: usize) { unsafe { - let a_ptr: *mut VecZnx = a as *mut VecZnx; - Self::vec_znx_rotate(self, k, &mut *a_ptr, a_col, &*a_ptr, a_col); + Self::vec_znx_rotate(self, k, a, a_col, a, a_col); } } - fn vec_znx_automorphism(&self, k: i64, res: &mut VecZnx, res_col: usize, a: &VecZnx, a_col: usize) { + fn vec_znx_automorphism(&self, k: i64, res: &mut VecZnx, res_col: usize, a: &VecZnx, a_col: usize) { #[cfg(debug_assertions)] { assert_eq!(a.n(), self.n()); @@ -315,14 +358,20 @@ impl VecZnxOps for Module { } } - fn vec_znx_automorphism_inplace(&self, k: i64, a: &mut VecZnx, a_col: usize) { + fn vec_znx_automorphism_inplace(&self, k: i64, a: &mut VecZnx, a_col: usize) { unsafe { - let a_ptr: *mut VecZnx = a as *mut VecZnx; - Self::vec_znx_automorphism(self, k, &mut *a_ptr, a_col, &*a_ptr, a_col); + Self::vec_znx_automorphism(self, k, a, a_col, a, a_col); } } - fn vec_znx_split(&self, res: &mut Vec, res_col: usize, a: &VecZnx, a_col: usize, buf: &mut VecZnx) { + fn vec_znx_split( + &self, + res: &mut Vec>, + res_col: usize, + a: &VecZnx, + a_col: usize, + buf: &mut VecZnx, + ) { let (n_in, n_out) = (a.n(), res[0].n()); debug_assert!( @@ -348,7 +397,7 @@ impl VecZnxOps for Module { }) } - fn vec_znx_merge(&self, res: &mut VecZnx, res_col: usize, a: &Vec, a_col: usize) { + fn vec_znx_merge(&self, res: &mut VecZnx, res_col: usize, a: &Vec>, a_col: usize) { let (n_in, n_out) = (res.n(), a[0].n()); debug_assert!( diff --git a/base2k/src/znx_base.rs b/base2k/src/znx_base.rs index 4cacb70..bf941d4 100644 --- a/base2k/src/znx_base.rs +++ b/base2k/src/znx_base.rs @@ -54,11 +54,9 @@ pub trait GetZnxBase { fn znx_mut(&mut self) -> &mut ZnxBase; } -pub trait ZnxInfos: GetZnxBase { +pub trait ZnxInfos { /// Returns the ring degree of the polynomials. - fn n(&self) -> usize { - self.znx().n - } + fn n(&self) -> usize; /// Returns the base two logarithm of the ring dimension of the polynomials. fn log_n(&self) -> usize { @@ -66,41 +64,27 @@ pub trait ZnxInfos: GetZnxBase { } /// Returns the number of rows. - fn rows(&self) -> usize { - self.znx().rows - } + fn rows(&self) -> usize; + /// Returns the number of polynomials in each row. - fn cols(&self) -> usize { - self.znx().cols - } + fn cols(&self) -> usize; /// Returns the number of size per polynomial. - fn size(&self) -> usize { - self.znx().size - } - - /// Returns the underlying raw bytes array. - fn data(&self) -> &[u8] { - &self.znx().data - } - - /// Returns a pointer to the underlying raw bytes array. - fn ptr(&self) -> *mut u8 { - self.znx().ptr - } + fn size(&self) -> usize; /// Returns the total number of small polynomials. fn poly_count(&self) -> usize { self.rows() * self.cols() * self.size() } -} -pub trait ZnxSliceSize { /// Returns the slice size, which is the offset between /// two size of the same column. fn sl(&self) -> usize; } +// pub trait ZnxSliceSize {} + +//(Jay) TODO: Remove ZnxAlloc pub trait ZnxAlloc where Self: Sized + ZnxInfos, @@ -122,22 +106,21 @@ where fn bytes_of(module: &Module, rows: usize, cols: usize, size: usize) -> usize; } -pub trait ZnxLayout: ZnxInfos { - type Scalar; +pub trait DataView { + type D; + fn data(&self) -> &Self::D; +} - /// Returns true if the receiver is only borrowing the data. - fn borrowing(&self) -> bool { - self.znx().data.len() == 0 - } +pub trait DataViewMut: DataView { + fn data_mut(&self) -> &mut Self::D; +} + +pub trait ZnxView: ZnxInfos + DataView> { + type Scalar; /// Returns a non-mutable pointer to the underlying coefficients array. fn as_ptr(&self) -> *const Self::Scalar { - self.znx().ptr as *const Self::Scalar - } - - /// Returns a mutable pointer to the underlying coefficients array. - fn as_mut_ptr(&mut self) -> *mut Self::Scalar { - self.znx_mut().ptr as *mut Self::Scalar + self.data().as_ref().as_ptr() as *const Self::Scalar } /// Returns a non-mutable reference to the entire underlying coefficient array. @@ -145,11 +128,6 @@ pub trait ZnxLayout: ZnxInfos { unsafe { std::slice::from_raw_parts(self.as_ptr(), self.n() * self.poly_count()) } } - /// Returns a mutable reference to the entire underlying coefficient array. - fn raw_mut(&mut self) -> &mut [Self::Scalar] { - unsafe { std::slice::from_raw_parts_mut(self.as_mut_ptr(), self.n() * self.poly_count()) } - } - /// Returns a non-mutable pointer starting at the j-th small polynomial of the i-th column. fn at_ptr(&self, i: usize, j: usize) -> *const Self::Scalar { #[cfg(debug_assertions)] @@ -161,6 +139,23 @@ pub trait ZnxLayout: ZnxInfos { unsafe { self.as_ptr().add(offset) } } + /// Returns non-mutable reference to the (i, j)-th small polynomial. + fn at(&self, i: usize, j: usize) -> &[Self::Scalar] { + unsafe { std::slice::from_raw_parts(self.at_ptr(i, j), self.n()) } + } +} + +pub trait ZnxViewMut: ZnxView + DataViewMut> { + /// Returns a mutable pointer to the underlying coefficients array. + fn as_mut_ptr(&mut self) -> *mut Self::Scalar { + self.data_mut().as_mut().as_mut_ptr() as *mut Self::Scalar + } + + /// Returns a mutable reference to the entire underlying coefficient array. + fn raw_mut(&mut self) -> &mut [Self::Scalar] { + unsafe { std::slice::from_raw_parts_mut(self.as_mut_ptr(), self.n() * self.poly_count()) } + } + /// Returns a mutable pointer starting at the j-th small polynomial of the i-th column. fn at_mut_ptr(&mut self, i: usize, j: usize) -> *mut Self::Scalar { #[cfg(debug_assertions)] @@ -172,17 +167,15 @@ pub trait ZnxLayout: ZnxInfos { unsafe { self.as_mut_ptr().add(offset) } } - /// Returns non-mutable reference to the (i, j)-th small polynomial. - fn at(&self, i: usize, j: usize) -> &[Self::Scalar] { - unsafe { std::slice::from_raw_parts(self.at_ptr(i, j), self.n()) } - } - /// Returns mutable reference to the (i, j)-th small polynomial. fn at_mut(&mut self, i: usize, j: usize) -> &mut [Self::Scalar] { unsafe { std::slice::from_raw_parts_mut(self.at_mut_ptr(i, j), self.n()) } } } +//(Jay)Note: Can't provide blanket impl. of ZnxView because Scalar is not known +impl ZnxViewMut for T where T: ZnxView + DataViewMut> {} + use std::convert::TryFrom; use std::num::TryFromIntError; use std::ops::{Add, AddAssign, Div, Mul, Neg, Shl, Shr, Sub}; @@ -213,7 +206,7 @@ impl IntegerType for i128 { const BITS: u32 = 128; } -pub trait ZnxZero: ZnxLayout +pub trait ZnxZero: ZnxViewMut where Self: Sized, { @@ -238,16 +231,16 @@ where } } -pub trait ZnxRsh: ZnxLayout + ZnxZero -where - Self: Sized, - Self::Scalar: IntegerType, -{ +pub trait ZnxRsh: ZnxZero { fn rsh(&mut self, k: usize, log_base2k: usize, col: usize, carry: &mut [u8]) { rsh(k, log_base2k, self, col, carry) } } +// Blanket implementations +impl ZnxZero for T where T: ZnxViewMut {} +impl ZnxRsh for T where T: ZnxZero {} + pub fn rsh(k: usize, log_base2k: usize, a: &mut V, a_col: usize, tmp_bytes: &mut [u8]) where V::Scalar: IntegerType, @@ -310,10 +303,7 @@ pub fn rsh_tmp_bytes(n: usize) -> usize { n * std::mem::size_of::() } -pub fn switch_degree(b: &mut T, col_b: usize, a: &T, col_a: usize) -where - ::Scalar: IntegerType, -{ +pub fn switch_degree(b: &mut DMut, col_b: usize, a: &D, col_a: usize) { let (n_in, n_out) = (a.n(), b.n()); let (gap_in, gap_out): (usize, usize); @@ -334,3 +324,64 @@ where .for_each(|(x_in, x_out)| *x_out = *x_in); }); } + +// pub trait ZnxLayout: ZnxInfos { +// type Scalar; + +// /// Returns true if the receiver is only borrowing the data. +// fn borrowing(&self) -> bool { +// self.znx().data.len() == 0 +// } + +// /// Returns a non-mutable pointer to the underlying coefficients array. +// fn as_ptr(&self) -> *const Self::Scalar { +// self.znx().ptr as *const Self::Scalar +// } + +// /// Returns a mutable pointer to the underlying coefficients array. +// fn as_mut_ptr(&mut self) -> *mut Self::Scalar { +// self.znx_mut().ptr as *mut Self::Scalar +// } + +// /// Returns a non-mutable reference to the entire underlying coefficient array. +// fn raw(&self) -> &[Self::Scalar] { +// unsafe { std::slice::from_raw_parts(self.as_ptr(), self.n() * self.poly_count()) } +// } + +// /// Returns a mutable reference to the entire underlying coefficient array. +// fn raw_mut(&mut self) -> &mut [Self::Scalar] { +// unsafe { std::slice::from_raw_parts_mut(self.as_mut_ptr(), self.n() * self.poly_count()) } +// } + +// /// Returns a non-mutable pointer starting at the j-th small polynomial of the i-th column. +// fn at_ptr(&self, i: usize, j: usize) -> *const Self::Scalar { +// #[cfg(debug_assertions)] +// { +// assert!(i < self.cols()); +// assert!(j < self.size()); +// } +// let offset: usize = self.n() * (j * self.cols() + i); +// unsafe { self.as_ptr().add(offset) } +// } + +// /// Returns a mutable pointer starting at the j-th small polynomial of the i-th column. +// fn at_mut_ptr(&mut self, i: usize, j: usize) -> *mut Self::Scalar { +// #[cfg(debug_assertions)] +// { +// assert!(i < self.cols()); +// assert!(j < self.size()); +// } +// let offset: usize = self.n() * (j * self.cols() + i); +// unsafe { self.as_mut_ptr().add(offset) } +// } + +// /// Returns non-mutable reference to the (i, j)-th small polynomial. +// fn at(&self, i: usize, j: usize) -> &[Self::Scalar] { +// unsafe { std::slice::from_raw_parts(self.at_ptr(i, j), self.n()) } +// } + +// /// Returns mutable reference to the (i, j)-th small polynomial. +// fn at_mut(&mut self, i: usize, j: usize) -> &mut [Self::Scalar] { +// unsafe { std::slice::from_raw_parts_mut(self.at_mut_ptr(i, j), self.n()) } +// } +// }