diff --git a/base2k/examples/rlwe_encrypt.rs b/base2k/examples/rlwe_encrypt.rs
index 2f08633..afac2f8 100644
--- a/base2k/examples/rlwe_encrypt.rs
+++ b/base2k/examples/rlwe_encrypt.rs
@@ -1,5 +1,5 @@
 use base2k::{
-    Encoding, FFT64, Module, Sampling, Scalar, ScalarOps, ScalarZnxDft, ScalarZnxDftOps, VecZnx, VecZnxBig, VecZnxBigOps,
+    Encoding, FFT64, Module, Sampling, Scalar, ScalarAlloc, ScalarZnxDft, ScalarZnxDftOps, VecZnx, VecZnxBig, VecZnxBigOps,
     VecZnxDft, VecZnxDftOps, VecZnxOps, ZnxInfos, alloc_aligned,
 };
 use itertools::izip;
diff --git a/base2k/src/encoding.rs b/base2k/src/encoding.rs
index b7d014d..ba48474 100644
--- a/base2k/src/encoding.rs
+++ b/base2k/src/encoding.rs
@@ -1,5 +1,5 @@
 use crate::ffi::znx::znx_zero_i64_ref;
-use crate::znx_base::ZnxLayout;
+use crate::znx_base::{ZnxView, ZnxViewMut};
 use crate::{VecZnx, znx_base::ZnxInfos};
 use itertools::izip;
 use rug::{Assign, Float};
@@ -59,7 +59,7 @@ pub trait Encoding {
     fn decode_coeff_i64(&self, col_i: usize, log_base2k: usize, log_k: usize, i: usize) -> i64;
 }
 
-impl Encoding for VecZnx {
+impl + AsRef<[u8]>> Encoding for VecZnx {
     fn encode_vec_i64(&mut self, col_i: usize, log_base2k: usize, log_k: usize, data: &[i64], log_max: usize) {
         encode_vec_i64(self, col_i, log_base2k, log_k, data, log_max)
     }
@@ -81,7 +81,14 @@ impl Encoding for VecZnx {
     }
 }
 
-fn encode_vec_i64(a: &mut VecZnx, col_i: usize, log_base2k: usize, log_k: usize, data: &[i64], log_max: usize) {
+fn encode_vec_i64 + AsRef<[u8]>>(
+    a: &mut VecZnx,
+    col_i: usize,
+    log_base2k: usize,
+    log_k: usize,
+    data: &[i64],
+    log_max: usize,
+) {
     let size: usize = (log_k + log_base2k - 1) / log_base2k;
 
     #[cfg(debug_assertions)]
@@ -132,7 +139,7 @@ fn encode_vec_i64(a: &mut VecZnx, col_i: usize, log_base2k: usize, log_k: usize,
     }
 }
 
-fn decode_vec_i64(a: &VecZnx, col_i: usize, log_base2k: usize, log_k: usize, data: &mut [i64]) {
+fn decode_vec_i64 + AsRef<[u8]>>(a: &VecZnx, col_i: usize, log_base2k: usize, log_k: usize, data: &mut [i64]) {
     let size: usize = (log_k + log_base2k - 1) / log_base2k;
     #[cfg(debug_assertions)]
     {
@@ -160,7 +167,7 @@ fn decode_vec_i64(a: &VecZnx, col_i: usize, log_base2k: usize, log_k: usize, dat
     })
 }
 
-fn decode_vec_float(a: &VecZnx, col_i: usize, log_base2k: usize, data: &mut [Float]) {
+fn decode_vec_float + AsRef<[u8]>>(a: &VecZnx, col_i: usize, log_base2k: usize, data: &mut [Float]) {
     let size: usize = a.size();
     #[cfg(debug_assertions)]
     {
@@ -194,7 +201,15 @@ fn decode_vec_float(a: &VecZnx, col_i: usize, log_base2k: usize, data: &mut [Flo
     });
 }
 
-fn encode_coeff_i64(a: &mut VecZnx, col_i: usize, log_base2k: usize, log_k: usize, i: usize, value: i64, log_max: usize) {
+fn encode_coeff_i64 + AsRef<[u8]>>(
+    a: &mut VecZnx,
+    col_i: usize,
+    log_base2k: usize,
+    log_k: usize,
+    i: usize,
+    value: i64,
+    log_max: usize,
+) {
     let size: usize = (log_k + log_base2k - 1) / log_base2k;
 
     #[cfg(debug_assertions)]
@@ -237,7 +252,7 @@ fn encode_coeff_i64(a: &mut VecZnx, col_i: usize, log_base2k: usize, log_k: usiz
     }
 }
 
-fn decode_coeff_i64(a: &VecZnx, col_i: usize, log_base2k: usize, log_k: usize, i: usize) -> i64 {
+fn decode_coeff_i64 + AsRef<[u8]>>(a: &VecZnx, col_i: usize, log_base2k: usize, log_k: usize, i: usize) -> i64 {
     #[cfg(debug_assertions)]
     {
         assert!(i < a.n());
@@ -263,10 +278,9 @@ fn decode_coeff_i64(a: &VecZnx, col_i: usize, log_base2k: usize, log_k: usize, i
 
 #[cfg(test)]
 mod tests {
-    use crate::{
-        Encoding, FFT64, Module, VecZnx, VecZnxOps,
-        znx_base::{ZnxInfos, ZnxLayout},
-    };
+    use crate::vec_znx_ops::*;
+    use crate::znx_base::*;
+    use crate::{Encoding, FFT64, Module, VecZnx, znx_base::ZnxInfos};
     use itertools::izip;
     use sampling::source::Source;
 
@@ -277,7 +291,7 @@ mod tests {
         let log_base2k: usize = 17;
         let size: usize = 5;
         let log_k: usize = size * log_base2k - 5;
-        let mut a: VecZnx = module.new_vec_znx(2, size);
+        let mut a: VecZnx<_> = module.new_vec_znx(2, size);
         let mut source: Source = Source::new([0u8; 32]);
         let raw: &mut [i64] = a.raw_mut();
         raw.iter_mut().enumerate().for_each(|(i, x)| *x = i as i64);
@@ -299,7 +313,7 @@ mod tests {
         let log_base2k: usize = 17;
         let size: usize = 5;
         let log_k: usize = size * log_base2k - 5;
-        let mut a: VecZnx = module.new_vec_znx(2, size);
+        let mut a: VecZnx<_> = module.new_vec_znx(2, size);
         let mut source = Source::new([0u8; 32]);
         let raw: &mut [i64] = a.raw_mut();
         raw.iter_mut().enumerate().for_each(|(i, x)| *x = i as i64);
diff --git a/base2k/src/lib.rs b/base2k/src/lib.rs
index 73d90c2..7ae1193 100644
--- a/base2k/src/lib.rs
+++ b/base2k/src/lib.rs
@@ -125,3 +125,29 @@ pub fn alloc_aligned(size: usize) -> Vec {
         DEFAULTALIGN,
     )
 }
+
+pub(crate) struct ScratchSpace {
+    // data: D,
+}
+
+impl ScratchSpace {
+    fn tmp_vec_znx_dft(&mut self, n: usize, cols: usize, size: usize) -> VecZnxDft {
+        todo!()
+    }
+
+    fn tmp_vec_znx_big(&mut self, n: usize, cols: usize, size: usize) -> VecZnxBig {
+        todo!()
+    }
+
+    fn vec_znx_big_normalize_tmp_bytes(&mut self, module: &Module) -> &mut [u8] {
+        todo!()
+    }
+
+    fn vmp_apply_dft_tmp_bytes(&mut self, module: &Module) -> &mut [u8] {
+        todo!()
+    }
+
+    fn vmp_apply_dft_to_dft_tmp_bytes(&mut self, module: &Module) -> &mut [u8] {
+        todo!()
+    }
+}
diff --git a/base2k/src/mat_znx_dft.rs b/base2k/src/mat_znx_dft.rs
index 470adcc..34c711a 100644
--- a/base2k/src/mat_znx_dft.rs
+++ b/base2k/src/mat_znx_dft.rs
@@ -1,5 +1,5 @@
-use crate::znx_base::{GetZnxBase, ZnxBase, ZnxInfos, ZnxLayout, ZnxSliceSize};
-use crate::{Backend, FFT64, Module, alloc_aligned};
+use crate::znx_base::{GetZnxBase, ZnxBase, ZnxInfos};
+use crate::{Backend, DataView, DataViewMut, FFT64, Module, ZnxView, alloc_aligned};
 use std::marker::PhantomData;
 
 /// Vector Matrix Product Prepared Matrix: a vector of [VecZnx],
@@ -8,68 +8,67 @@ use std::marker::PhantomData;
 ///
 /// [MatZnxDft] is used to permform a vector matrix product between a [VecZnx]/[VecZnxDft] and a [MatZnxDft].
 /// See the trait [MatZnxDftOps] for additional information.
-pub struct MatZnxDft { - pub inner: ZnxBase, - pub cols_in: usize, - pub cols_out: usize, +pub struct MatZnxDft { + data: D, + n: usize, + size: usize, + rows: usize, + cols_in: usize, + cols_out: usize, _marker: PhantomData, } -impl GetZnxBase for MatZnxDft { - fn znx(&self) -> &ZnxBase { - &self.inner +impl ZnxInfos for MatZnxDft { + fn cols(&self) -> usize { + self.cols_in } - fn znx_mut(&mut self) -> &mut ZnxBase { - &mut self.inner + fn rows(&self) -> usize { + self.rows } -} -impl ZnxInfos for MatZnxDft {} + fn n(&self) -> usize { + self.n + } + + fn size(&self) -> usize { + self.size + } -impl ZnxSliceSize for MatZnxDft { fn sl(&self) -> usize { self.n() } } -impl ZnxLayout for MatZnxDft { +impl DataView for MatZnxDft { + type D = D; + fn data(&self) -> &Self::D { + &self.data + } +} + +impl DataViewMut for MatZnxDft { + fn data_mut(&mut self) -> &mut Self::D { + &mut self.data + } +} + +impl> ZnxView for MatZnxDft { type Scalar = f64; } -impl MatZnxDft { - pub fn new(module: &Module, rows: usize, cols_in: usize, cols_out: usize, size: usize) -> Self { - let bytes: Vec = alloc_aligned(Self::bytes_of(module, rows, cols_in, cols_out, size)); - Self::from_bytes(module, rows, cols_in, cols_out, size, bytes) +impl MatZnxDft { + pub(crate) fn cols_in(&self) -> usize { + self.cols_in } - pub fn from_bytes(module: &Module, rows: usize, cols_in: usize, cols_out: usize, size: usize, mut bytes: Vec) -> Self { - let mut mat: MatZnxDft = Self::from_bytes_borrow(module, rows, cols_in, cols_out, size, &mut bytes); - mat.znx_mut().data = bytes; - mat + pub(crate) fn cols_out(&self) -> usize { + self.cols_out } +} - pub fn from_bytes_borrow( - module: &Module, - rows: usize, - cols_in: usize, - cols_out: usize, - size: usize, - bytes: &mut [u8], - ) -> Self { - debug_assert_eq!( - bytes.len(), - Self::bytes_of(module, rows, cols_in, cols_out, size) - ); - Self { - inner: ZnxBase::from_bytes_borrow(module.n(), rows, cols_out, size, bytes), - cols_in: cols_in, - cols_out: cols_out, - _marker: PhantomData, - } - } - - pub fn bytes_of(module: &Module, rows: usize, cols_in: usize, cols_out: usize, size: usize) -> usize { +impl>, B: Backend> MatZnxDft { + pub(crate) fn bytes_of(module: &Module, rows: usize, cols_in: usize, cols_out: usize, size: usize) -> usize { unsafe { crate::ffi::vmp::bytes_of_vmp_pmat( module.ptr, @@ -79,16 +78,62 @@ impl MatZnxDft { } } - pub fn cols_in(&self) -> usize { - self.cols_in + pub(crate) fn new(module: &Module, rows: usize, cols_in: usize, cols_out: usize, size: usize) -> Self { + let data: Vec = alloc_aligned(Self::bytes_of(module, rows, cols_in, cols_out, size)); + Self { + data: data.into(), + n: module.n(), + size, + rows, + cols_in, + cols_out, + _marker: PhantomData, + } } - pub fn cols_out(&self) -> usize { - self.cols_out + pub(crate) fn new_from_bytes( + module: &Module, + rows: usize, + cols_in: usize, + cols_out: usize, + size: usize, + bytes: impl Into>, + ) -> Self { + let data: Vec = bytes.into(); + assert!(data.len() == Self::bytes_of(module, rows, cols_in, cols_out, size)); + Self { + data: data.into(), + n: module.n(), + size, + rows, + cols_in, + cols_out, + _marker: PhantomData, + } } + + // pub fn from_bytes_borrow( + // module: &Module, + // rows: usize, + // cols_in: usize, + // cols_out: usize, + // size: usize, + // bytes: &mut [u8], + // ) -> Self { + // debug_assert_eq!( + // bytes.len(), + // Self::bytes_of(module, rows, cols_in, cols_out, size) + // ); + // Self { + // inner: ZnxBase::from_bytes_borrow(module.n(), rows, cols_out, size, bytes), 
+ // cols_in: cols_in, + // cols_out: cols_out, + // _marker: PhantomData, + // } + // } } -impl MatZnxDft { +impl> MatZnxDft { /// Returns a copy of the backend array at index (i, j) of the [MatZnxDft]. /// /// # Arguments @@ -123,3 +168,5 @@ impl MatZnxDft { } } } + +pub type MatZnxDftAllocOwned = MatZnxDft, B>; diff --git a/base2k/src/mat_znx_dft_ops.rs b/base2k/src/mat_znx_dft_ops.rs index 48c3834..62b56a1 100644 --- a/base2k/src/mat_znx_dft_ops.rs +++ b/base2k/src/mat_znx_dft_ops.rs @@ -1,20 +1,19 @@ use crate::ffi::vec_znx_dft::vec_znx_dft_t; use crate::ffi::vmp; -use crate::znx_base::{ZnxInfos, ZnxLayout}; +use crate::znx_base::{ZnxInfos, ZnxView, ZnxViewMut}; use crate::{ - Backend, FFT64, MatZnxDft, Module, VecZnx, VecZnxBig, VecZnxBigOps, VecZnxDft, VecZnxDftOps, assert_alignement, is_aligned, + Backend, FFT64, MatZnxDft, MatZnxDftAllocOwned, Module, ScratchSpace, VecZnx, VecZnxBig, VecZnxBigOps, VecZnxDft, + VecZnxDftAlloc, VecZnxDftOps, assert_alignement, is_aligned, }; -/// This trait implements methods for vector matrix product, -/// that is, multiplying a [VecZnx] with a [MatZnxDft]. -pub trait MatZnxDftOps { +pub trait MatZnxDftAlloc { /// Allocates a new [MatZnxDft] with the given number of rows and columns. /// /// # Arguments /// /// * `rows`: number of rows (number of [VecZnxDft]). /// * `size`: number of size (number of size of each [VecZnxDft]). - fn new_mat_znx_dft(&self, rows: usize, cols_in: usize, cols_out: usize, size: usize) -> MatZnxDft; + fn new_mat_znx_dft(&self, rows: usize, cols_in: usize, cols_out: usize, size: usize) -> MatZnxDftAllocOwned; fn bytes_of_mat_znx_dft(&self, rows: usize, cols_in: usize, cols_out: usize, size: usize) -> usize; @@ -25,17 +24,21 @@ pub trait MatZnxDftOps { cols_out: usize, size: usize, bytes: Vec, - ) -> MatZnxDft; + ) -> MatZnxDftAllocOwned; - fn new_mat_znx_dft_from_bytes_borrow( - &self, - rows: usize, - cols_in: usize, - cols_out: usize, - size: usize, - bytes: &mut [u8], - ) -> MatZnxDft; + // fn new_mat_znx_dft_from_bytes_borrow( + // &self, + // rows: usize, + // cols_in: usize, + // cols_out: usize, + // size: usize, + // bytes: &mut [u8], + // ) -> MatZnxDft; +} +/// This trait implements methods for vector matrix product, +/// that is, multiplying a [VecZnx] with a [MatZnxDft]. +pub trait MatZnxDftOps { /// Returns the of bytes needed as scratch space for [MatZnxDftOps::vmp_prepare_row] fn vmp_prepare_row_tmp_bytes(&self, cols_out: usize, size: usize) -> usize; @@ -49,7 +52,14 @@ pub trait MatZnxDftOps { /// * `buf`: scratch space, the size of buf can be obtained with [MatZnxDftOps::vmp_prepare_tmp_bytes]. /// /// The size of buf can be obtained with [MatZnxDftOps::vmp_prepare_tmp_bytes]. - fn vmp_prepare_row(&self, b: &mut MatZnxDft, b_row: usize, b_col_in: usize, a: &VecZnx, tmp_bytes: &mut [u8]); + fn vmp_prepare_row( + &self, + b: &mut MatZnxDft, + b_row: usize, + b_col_in: usize, + a: &VecZnx, + scratch: &mut ScratchSpace, + ); /// Returns the of bytes needed as scratch space for [MatZnxDftOps::vmp_extract_row] fn vmp_extract_row_tmp_bytes(&self, cols_out: usize, size: usize) -> usize; @@ -64,11 +74,11 @@ pub trait MatZnxDftOps { fn vmp_extract_row( &self, log_base2k: usize, - b: &mut VecZnx, - a: &MatZnxDft, + b: &mut VecZnx, + a: &MatZnxDft, b_row: usize, b_col_in: usize, - tmp_bytes: &mut [u8], + scratch: &mut ScratchSpace, ); /// Prepares the ith-row of [MatZnxDft] from a [VecZnxDft]. @@ -80,7 +90,7 @@ pub trait MatZnxDftOps { /// * `row_i`: the index of the row to prepare. 
/// /// The size of buf can be obtained with [MatZnxDftOps::vmp_prepare_tmp_bytes]. - fn vmp_prepare_row_dft(&self, b: &mut MatZnxDft, b_row: usize, b_col_in: usize, a: &VecZnxDft); + fn vmp_prepare_row_dft(&self, b: &mut MatZnxDft, b_row: usize, b_col_in: usize, a: &VecZnxDft); /// Extracts the ith-row of [MatZnxDft] into a [VecZnxDft]. /// @@ -89,7 +99,7 @@ pub trait MatZnxDftOps { /// * `b`: the [VecZnxDft] to on which to extract the row of the [MatZnxDft]. /// * `a`: [MatZnxDft] on which the values are encoded. /// * `row_i`: the index of the row to extract. - fn vmp_extract_row_dft(&self, b: &mut VecZnxDft, a: &MatZnxDft, a_row: usize, a_col_in: usize); + fn vmp_extract_row_dft(&self, b: &mut VecZnxDft, a: &MatZnxDft, a_row: usize, a_col_in: usize); /// Returns the size of the stratch space necessary for [MatZnxDftOps::vmp_apply_dft]. /// @@ -133,7 +143,7 @@ pub trait MatZnxDftOps { /// * `a`: the left operand [VecZnx] of the vector matrix product. /// * `b`: the right operand [MatZnxDft] of the vector matrix product. /// * `buf`: scratch space, the size can be obtained with [MatZnxDftOps::vmp_apply_dft_tmp_bytes]. - fn vmp_apply_dft(&self, c: &mut VecZnxDft, a: &VecZnx, b: &MatZnxDft, buf: &mut [u8]); + fn vmp_apply_dft(&self, c: &mut VecZnxDft, a: &VecZnx, b: &MatZnxDft, scratch: &mut ScratchSpace); /// Returns the size of the stratch space necessary for [MatZnxDftOps::vmp_apply_dft_to_dft]. /// @@ -180,16 +190,22 @@ pub trait MatZnxDftOps { /// * `a`: the left operand [VecZnxDft] of the vector matrix product. /// * `b`: the right operand [MatZnxDft] of the vector matrix product. /// * `buf`: scratch space, the size can be obtained with [MatZnxDftOps::vmp_apply_dft_to_dft_tmp_bytes]. - fn vmp_apply_dft_to_dft(&self, c: &mut VecZnxDft, a: &VecZnxDft, b: &MatZnxDft, buf: &mut [u8]); + fn vmp_apply_dft_to_dft( + &self, + c: &mut VecZnxDft, + a: &VecZnxDft, + b: &MatZnxDft, + scratch: &mut ScratchSpace, + ); } -impl MatZnxDftOps for Module { - fn new_mat_znx_dft(&self, rows: usize, cols_in: usize, cols_out: usize, size: usize) -> MatZnxDft { - MatZnxDft::::new(self, rows, cols_in, cols_out, size) +impl MatZnxDftAlloc for Module { + fn bytes_of_mat_znx_dft(&self, rows: usize, cols_in: usize, cols_out: usize, size: usize) -> usize { + MatZnxDftAllocOwned::bytes_of(self, rows, cols_in, cols_out, size) } - fn bytes_of_mat_znx_dft(&self, rows: usize, cols_in: usize, cols_out: usize, size: usize) -> usize { - MatZnxDft::::bytes_of(self, rows, cols_in, cols_out, size) + fn new_mat_znx_dft(&self, rows: usize, cols_in: usize, cols_out: usize, size: usize) -> MatZnxDftAllocOwned { + MatZnxDftAllocOwned::new(self, rows, cols_in, cols_out, size) } fn new_mat_znx_dft_from_bytes( @@ -199,26 +215,28 @@ impl MatZnxDftOps for Module { cols_out: usize, size: usize, bytes: Vec, - ) -> MatZnxDft { - MatZnxDft::::from_bytes(self, rows, cols_in, cols_out, size, bytes) - } - - fn new_mat_znx_dft_from_bytes_borrow( - &self, - rows: usize, - cols_in: usize, - cols_out: usize, - size: usize, - bytes: &mut [u8], - ) -> MatZnxDft { - MatZnxDft::::from_bytes_borrow(self, rows, cols_in, cols_out, size, bytes) + ) -> MatZnxDftAllocOwned { + MatZnxDftAllocOwned::new_from_bytes(self, rows, cols_in, cols_out, size, bytes) } +} +impl MatZnxDftOps for Module +where + DataMut: AsMut<[u8]> + AsRef<[u8]>, + Data: AsRef<[u8]>, +{ fn vmp_prepare_row_tmp_bytes(&self, cols_out: usize, size: usize) -> usize { - self.bytes_of_vec_znx_dft(cols_out, size) + >::bytes_of_vec_znx_dft(self, cols_out, size) } - fn vmp_prepare_row(&self, 
b: &mut MatZnxDft, b_row: usize, b_col_in: usize, a: &VecZnx, tmp_bytes: &mut [u8]) { + fn vmp_prepare_row( + &self, + b: &mut MatZnxDft, + b_row: usize, + b_col_in: usize, + a: &VecZnx, + scratch: &mut ScratchSpace, + ) { #[cfg(debug_assertions)] { assert_eq!(b.n(), self.n()); @@ -249,33 +267,36 @@ impl MatZnxDftOps for Module { b.size(), a.size() ); - assert!(tmp_bytes.len() >= self.vmp_prepare_row_tmp_bytes(a.cols(), a.size())); - assert!(is_aligned(tmp_bytes.as_ptr())) + // assert!( + // tmp_bytes.len() + // >= >::vmp_prepare_row_tmp_bytes(self, a.cols(), a.size()) + // ); + // assert!(is_aligned(tmp_bytes.as_ptr())) } let cols_out: usize = a.cols(); let a_size: usize = a.size(); - let (tmp_bytes_a_dft, _) = tmp_bytes.split_at_mut(self.bytes_of_vec_znx_dft(cols_out, a_size)); - - let mut a_dft: VecZnxDft = self.new_vec_znx_dft_from_bytes_borrow(cols_out, a_size, tmp_bytes_a_dft); + // let (tmp_bytes_a_dft, _) = tmp_bytes.split_at_mut(self.bytes_of_vec_znx_dft(cols_out, a_size)); + let mut a_dft = scratch.tmp_vec_znx_dft::(self.n(), cols_out, a_size); (0..cols_out).for_each(|i| self.vec_znx_dft(&mut a_dft, i, &a, i)); Self::vmp_prepare_row_dft(&self, b, b_row, b_col_in, &a_dft); } fn vmp_extract_row_tmp_bytes(&self, cols_out: usize, size: usize) -> usize { - self.bytes_of_vec_znx_dft(cols_out, size) + self.vec_znx_big_normalize_tmp_bytes() + self.bytes_of_vec_znx_dft(cols_out, size) + + >::vec_znx_big_normalize_tmp_bytes(self) } fn vmp_extract_row( &self, log_base2k: usize, - b: &mut VecZnx, - a: &MatZnxDft, + b: &mut VecZnx, + a: &MatZnxDft, a_row: usize, a_col_in: usize, - tmp_bytes: &mut [u8], + scratch: &mut ScratchSpace, ) { #[cfg(debug_assertions)] { @@ -307,24 +328,24 @@ impl MatZnxDftOps for Module { b.size(), a.size() ); - assert!(tmp_bytes.len() >= self.vmp_extract_row_tmp_bytes(a.cols(), a.size())); - assert!(is_aligned(tmp_bytes.as_ptr())) + // assert!(tmp_bytes.len() >= self.vmp_extract_row_tmp_bytes(a.cols(), a.size())); + // assert!(is_aligned(tmp_bytes.as_ptr())) } let cols_out: usize = b.cols(); let size: usize = b.size(); - let (bytes_a_dft, tmp_bytes) = tmp_bytes.split_at_mut(self.bytes_of_vec_znx_dft(cols_out, size)); - let mut b_dft: VecZnxDft = self.new_vec_znx_dft_from_bytes_borrow(cols_out, size, bytes_a_dft); + // let (bytes_a_dft, tmp_bytes) = tmp_bytes.split_at_mut(self.bytes_of_vec_znx_dft(cols_out, size)); + let mut b_dft = scratch.tmp_vec_znx_dft::(self.n(), cols_out, size); Self::vmp_extract_row_dft(&self, &mut b_dft, a, a_row, a_col_in); - let mut b_big: VecZnxBig = b_dft.alias_as_vec_znx_big(); + let mut b_big = scratch.tmp_vec_znx_big(self.n(), cols_out, size); (0..cols_out).for_each(|i| { - self.vec_znx_idft_tmp_a(&mut b_big, i, &mut b_dft, i); - self.vec_znx_big_normalize(log_base2k, b, i, &b_big, i, tmp_bytes); + >::vec_znx_idft_tmp_a(self, &mut b_big, i, &mut b_dft, i); + self.vec_znx_big_normalize(log_base2k, b, i, &b_big, i, scratch); }); } - fn vmp_prepare_row_dft(&self, b: &mut MatZnxDft, b_row: usize, b_col_in: usize, a: &VecZnxDft) { + fn vmp_prepare_row_dft(&self, b: &mut MatZnxDft, b_row: usize, b_col_in: usize, a: &VecZnxDft) { #[cfg(debug_assertions)] { assert_eq!(b.n(), self.n()); @@ -369,7 +390,7 @@ impl MatZnxDftOps for Module { } } - fn vmp_extract_row_dft(&self, b: &mut VecZnxDft, a: &MatZnxDft, a_row: usize, a_col_in: usize) { + fn vmp_extract_row_dft(&self, b: &mut VecZnxDft, a: &MatZnxDft, a_row: usize, a_col_in: usize) { #[cfg(debug_assertions)] { assert_eq!(b.n(), self.n()); @@ -433,18 +454,13 @@ impl MatZnxDftOps for Module { } 
} - fn vmp_apply_dft(&self, c: &mut VecZnxDft, a: &VecZnx, b: &MatZnxDft, tmp_bytes: &mut [u8]) { - debug_assert!( - tmp_bytes.len() - >= self.vmp_apply_dft_tmp_bytes( - c.size(), - a.size(), - b.rows(), - b.cols_in(), - b.cols_out(), - b.size() - ) - ); + fn vmp_apply_dft( + &self, + c: &mut VecZnxDft, + a: &VecZnx, + b: &MatZnxDft, + scratch: &mut ScratchSpace, + ) { #[cfg(debug_assertions)] { assert_eq!(c.n(), self.n()); @@ -464,18 +480,18 @@ impl MatZnxDftOps for Module { a.cols(), b.cols_in() ); - assert!( - tmp_bytes.len() - >= self.vmp_apply_dft_tmp_bytes( - c.size(), - a.size(), - b.rows(), - b.cols_in(), - b.cols_out(), - b.size() - ) - ); - assert_alignement(tmp_bytes.as_ptr()); + // assert!( + // tmp_bytes.len() + // >= self.vmp_apply_dft_tmp_bytes( + // c.size(), + // a.size(), + // b.rows(), + // b.cols_in(), + // b.cols_out(), + // b.size() + // ) + // ); + // assert_alignement(tmp_bytes.as_ptr()); } unsafe { vmp::vmp_apply_dft( @@ -488,7 +504,7 @@ impl MatZnxDftOps for Module { b.as_ptr() as *const vmp::vmp_pmat_t, (b.rows() * b.cols_in()) as u64, (b.size() * b.cols_out()) as u64, - tmp_bytes.as_mut_ptr(), + scratch.vmp_apply_dft_tmp_bytes(self).as_mut_ptr(), ) } } @@ -515,7 +531,13 @@ impl MatZnxDftOps for Module { } } - fn vmp_apply_dft_to_dft(&self, c: &mut VecZnxDft, a: &VecZnxDft, b: &MatZnxDft, tmp_bytes: &mut [u8]) { + fn vmp_apply_dft_to_dft( + &self, + c: &mut VecZnxDft, + a: &VecZnxDft, + b: &MatZnxDft, + scratch: &mut ScratchSpace, + ) { #[cfg(debug_assertions)] { assert_eq!(c.n(), self.n()); @@ -535,20 +557,20 @@ impl MatZnxDftOps for Module { a.cols(), b.cols_in() ); - assert!( - tmp_bytes.len() - >= self.vmp_apply_dft_to_dft_tmp_bytes( - c.cols(), - c.size(), - a.cols(), - a.size(), - b.rows(), - b.cols_in(), - b.cols_out(), - b.size() - ) - ); - assert_alignement(tmp_bytes.as_ptr()); + // assert!( + // tmp_bytes.len() + // >= self.vmp_apply_dft_to_dft_tmp_bytes( + // c.cols(), + // c.size(), + // a.cols(), + // a.size(), + // b.rows(), + // b.cols_in(), + // b.cols_out(), + // b.size() + // ) + // ); + // assert_alignement(tmp_bytes.as_ptr()); } unsafe { vmp::vmp_apply_dft_to_dft( @@ -560,7 +582,7 @@ impl MatZnxDftOps for Module { b.as_ptr() as *const vmp::vmp_pmat_t, b.rows() as u64, (b.size() * b.cols()) as u64, - tmp_bytes.as_mut_ptr(), + scratch.vmp_apply_dft_to_dft_tmp_bytes(self).as_mut_ptr(), ) } } @@ -568,9 +590,12 @@ impl MatZnxDftOps for Module { #[cfg(test)] mod tests { + use crate::mat_znx_dft_ops::*; + use crate::vec_znx_big_ops::*; + use crate::vec_znx_dft_ops::*; + use crate::vec_znx_ops::*; use crate::{ - FFT64, MatZnxDft, MatZnxDftOps, Module, Sampling, VecZnx, VecZnxBig, VecZnxBigOps, VecZnxDft, VecZnxDftOps, VecZnxOps, - alloc_aligned, znx_base::ZnxLayout, + FFT64, MatZnxDft, MatZnxDftOps, Module, Sampling, VecZnx, VecZnxBig, VecZnxBigOps, VecZnxDft, VecZnxDftOps, alloc_aligned, }; use sampling::source::Source; @@ -582,16 +607,19 @@ mod tests { let mat_cols_in: usize = 2; let mat_cols_out: usize = 2; let mat_size: usize = 5; - let mut a: VecZnx = module.new_vec_znx(mat_cols_out, mat_size); - let mut b: VecZnx = module.new_vec_znx(mat_cols_out, mat_size); - let mut a_dft: VecZnxDft = module.new_vec_znx_dft(mat_cols_out, mat_size); - let mut a_big: VecZnxBig = module.new_vec_znx_big(mat_cols_out, mat_size); - let mut b_dft: VecZnxDft = module.new_vec_znx_dft(mat_cols_out, mat_size); - let mut vmpmat_0: MatZnxDft = module.new_mat_znx_dft(mat_rows, mat_cols_in, mat_cols_out, mat_size); - let mut vmpmat_1: MatZnxDft = 
module.new_mat_znx_dft(mat_rows, mat_cols_in, mat_cols_out, mat_size); + let mut a: VecZnx<_> = module.new_vec_znx(mat_cols_out, mat_size); + let mut b: VecZnx<_> = module.new_vec_znx(mat_cols_out, mat_size); + let mut a_dft: VecZnxDft<_, FFT64> = module.new_vec_znx_dft(mat_cols_out, mat_size); + let mut a_big: VecZnxBig<_, FFT64> = module.new_vec_znx_big(mat_cols_out, mat_size); + let mut b_dft: VecZnxDft<_, FFT64> = module.new_vec_znx_dft(mat_cols_out, mat_size); + let mut vmpmat_0: MatZnxDft<_, FFT64> = module.new_mat_znx_dft(mat_rows, mat_cols_in, mat_cols_out, mat_size); + let mut vmpmat_1: MatZnxDft<_, FFT64> = module.new_mat_znx_dft(mat_rows, mat_cols_in, mat_cols_out, mat_size); + // let mut tmp_bytes: Vec = + // alloc_aligned(module.vmp_prepare_row_tmp_bytes(mat_cols_out, mat_size) | module.vec_znx_big_normalize_tmp_bytes()); + let mut scratch = ScratchSpace {}; let mut tmp_bytes: Vec = - alloc_aligned(module.vmp_prepare_row_tmp_bytes(mat_cols_out, mat_size) | module.vec_znx_big_normalize_tmp_bytes()); + alloc_aligned::( as VecZnxDftOps, Vec, _>>::vec_znx_idft_tmp_bytes(&module)); for col_in in 0..mat_cols_in { for row_i in 0..mat_rows { @@ -602,7 +630,7 @@ mod tests { module.vec_znx_dft(&mut a_dft, col_out, &a, col_out); }); - module.vmp_prepare_row(&mut vmpmat_0, row_i, col_in, &a, &mut tmp_bytes); + module.vmp_prepare_row(&mut vmpmat_0, row_i, col_in, &a, &mut scratch); // Checks that prepare(mat_znx_dft, a) = prepare_dft(mat_znx_dft, a_dft) module.vmp_prepare_row_dft(&mut vmpmat_1, row_i, col_in, &a_dft); @@ -613,11 +641,11 @@ mod tests { assert_eq!(a_dft.raw(), b_dft.raw()); // Checks that a_big = extract(prepare_dft(mat_znx_dft, a_dft), b_big) - module.vmp_extract_row(log_base2k, &mut b, &vmpmat_0, row_i, col_in, &mut tmp_bytes); + module.vmp_extract_row(log_base2k, &mut b, &vmpmat_0, row_i, col_in, &mut scratch); (0..mat_cols_out).for_each(|col_out| { module.vec_znx_idft(&mut a_big, col_out, &a_dft, col_out, &mut tmp_bytes); - module.vec_znx_big_normalize(log_base2k, &mut a, col_out, &a_big, col_out, &mut tmp_bytes); + module.vec_znx_big_normalize(log_base2k, &mut a, col_out, &a_big, col_out, &mut scratch); }); assert_eq!(a.raw(), b.raw()); diff --git a/base2k/src/module.rs b/base2k/src/module.rs index c1799be..0e7d124 100644 --- a/base2k/src/module.rs +++ b/base2k/src/module.rs @@ -33,7 +33,7 @@ impl Backend for NTT120 { pub struct Module { pub ptr: *mut MODULE, - pub n: usize, + n: usize, _marker: PhantomData, } diff --git a/base2k/src/sampling.rs b/base2k/src/sampling.rs index b52c4db..a8b1962 100644 --- a/base2k/src/sampling.rs +++ b/base2k/src/sampling.rs @@ -1,16 +1,24 @@ -use crate::{Backend, Module, VecZnx, znx_base::ZnxLayout}; +use crate::znx_base::ZnxViewMut; +use crate::{Backend, Module, VecZnx}; use rand_distr::{Distribution, Normal}; use sampling::source::Source; pub trait Sampling { /// Fills the first `size` size with uniform values in \[-2^{log_base2k-1}, 2^{log_base2k-1}\] - fn fill_uniform(&self, log_base2k: usize, a: &mut VecZnx, col_i: usize, size: usize, source: &mut Source); - - /// Adds vector sampled according to the provided distribution, scaled by 2^{-log_k} and bounded to \[-bound, bound\]. - fn add_dist_f64>( + fn fill_uniform + AsRef<[u8]>>( &self, log_base2k: usize, - a: &mut VecZnx, + a: &mut VecZnx, + col_i: usize, + size: usize, + source: &mut Source, + ); + + /// Adds vector sampled according to the provided distribution, scaled by 2^{-log_k} and bounded to \[-bound, bound\]. 
+ fn add_dist_f64 + AsRef<[u8]>, D: Distribution>( + &self, + log_base2k: usize, + a: &mut VecZnx, col_i: usize, log_k: usize, source: &mut Source, @@ -19,10 +27,10 @@ pub trait Sampling { ); /// Adds a discrete normal vector scaled by 2^{-log_k} with the provided standard deviation and bounded to \[-bound, bound\]. - fn add_normal( + fn add_normal + AsRef<[u8]>>( &self, log_base2k: usize, - a: &mut VecZnx, + a: &mut VecZnx, col_i: usize, log_k: usize, source: &mut Source, @@ -32,22 +40,29 @@ pub trait Sampling { } impl Sampling for Module { - fn fill_uniform(&self, log_base2k: usize, a: &mut VecZnx, col_a: usize, size: usize, source: &mut Source) { + fn fill_uniform + AsRef<[u8]>>( + &self, + log_base2k: usize, + a: &mut VecZnx, + col_i: usize, + size: usize, + source: &mut Source, + ) { let base2k: u64 = 1 << log_base2k; let mask: u64 = base2k - 1; let base2k_half: i64 = (base2k >> 1) as i64; (0..size).for_each(|j| { - a.at_mut(col_a, j) + a.at_mut(col_i, j) .iter_mut() .for_each(|x| *x = (source.next_u64n(base2k, mask) as i64) - base2k_half); }) } - fn add_dist_f64>( + fn add_dist_f64 + AsRef<[u8]>, D: Distribution>( &self, log_base2k: usize, - a: &mut VecZnx, - col_a: usize, + a: &mut VecZnx, + col_i: usize, log_k: usize, source: &mut Source, dist: D, @@ -63,7 +78,7 @@ impl Sampling for Module { let log_base2k_rem: usize = log_k % log_base2k; if log_base2k_rem != 0 { - a.at_mut(col_a, limb).iter_mut().for_each(|a| { + a.at_mut(col_i, limb).iter_mut().for_each(|a| { let mut dist_f64: f64 = dist.sample(source); while dist_f64.abs() > bound { dist_f64 = dist.sample(source) @@ -71,7 +86,7 @@ impl Sampling for Module { *a += (dist_f64.round() as i64) << log_base2k_rem; }); } else { - a.at_mut(col_a, limb).iter_mut().for_each(|a| { + a.at_mut(col_i, limb).iter_mut().for_each(|a| { let mut dist_f64: f64 = dist.sample(source); while dist_f64.abs() > bound { dist_f64 = dist.sample(source) @@ -81,11 +96,11 @@ impl Sampling for Module { } } - fn add_normal( + fn add_normal + AsRef<[u8]>>( &self, log_base2k: usize, - a: &mut VecZnx, - col_a: usize, + a: &mut VecZnx, + col_i: usize, log_k: usize, source: &mut Source, sigma: f64, @@ -94,7 +109,7 @@ impl Sampling for Module { self.add_dist_f64( log_base2k, a, - col_a, + col_i, log_k, source, Normal::new(0.0, sigma).unwrap(), @@ -106,7 +121,9 @@ impl Sampling for Module { #[cfg(test)] mod tests { use super::Sampling; - use crate::{FFT64, Module, Stats, VecZnx, VecZnxOps, znx_base::ZnxLayout}; + use crate::vec_znx_ops::*; + use crate::znx_base::*; + use crate::{FFT64, Module, Stats, VecZnx}; use sampling::source::Source; #[test] @@ -120,7 +137,7 @@ mod tests { let zero: Vec = vec![0; n]; let one_12_sqrt: f64 = 0.28867513459481287; (0..cols).for_each(|col_i| { - let mut a: VecZnx = module.new_vec_znx(cols, size); + let mut a: VecZnx<_> = module.new_vec_znx(cols, size); module.fill_uniform(log_base2k, &mut a, col_i, size, &mut source); (0..cols).for_each(|col_j| { if col_j != col_i { @@ -154,7 +171,7 @@ mod tests { let zero: Vec = vec![0; n]; let k_f64: f64 = (1u64 << log_k as u64) as f64; (0..cols).for_each(|col_i| { - let mut a: VecZnx = module.new_vec_znx(cols, size); + let mut a: VecZnx<_> = module.new_vec_znx(cols, size); module.add_normal(log_base2k, &mut a, col_i, log_k, &mut source, sigma, bound); (0..cols).for_each(|col_j| { if col_j != col_i { diff --git a/base2k/src/scalar_znx.rs b/base2k/src/scalar_znx.rs index df3e6d1..c5052eb 100644 --- a/base2k/src/scalar_znx.rs +++ b/base2k/src/scalar_znx.rs @@ -1,64 +1,59 @@ -use 
crate::znx_base::{ZnxAlloc, ZnxBase, ZnxInfos, ZnxLayout, ZnxSliceSize}; -use crate::{Backend, GetZnxBase, Module, VecZnx}; +use crate::znx_base::ZnxInfos; +use crate::{Backend, DataView, DataViewMut, Module, ZnxView, ZnxViewMut, alloc_aligned}; use rand::seq::SliceRandom; use rand_core::RngCore; use rand_distr::{Distribution, weighted::WeightedIndex}; use sampling::source::Source; -pub const SCALAR_ZNX_ROWS: usize = 1; -pub const SCALAR_ZNX_SIZE: usize = 1; +// pub const SCALAR_ZNX_ROWS: usize = 1; +// pub const SCALAR_ZNX_SIZE: usize = 1; -pub struct Scalar { - pub inner: ZnxBase, +pub struct Scalar { + data: D, + n: usize, + cols: usize, } -impl GetZnxBase for Scalar { - fn znx(&self) -> &ZnxBase { - &self.inner +impl ZnxInfos for Scalar { + fn cols(&self) -> usize { + self.cols } - fn znx_mut(&mut self) -> &mut ZnxBase { - &mut self.inner - } -} - -impl ZnxInfos for Scalar {} - -impl ZnxAlloc for Scalar { - type Scalar = i64; - - fn from_bytes_borrow(module: &Module, _rows: usize, cols: usize, _size: usize, bytes: &mut [u8]) -> Self { - Self { - inner: ZnxBase::from_bytes_borrow(module.n(), SCALAR_ZNX_ROWS, cols, SCALAR_ZNX_SIZE, bytes), - } + fn rows(&self) -> usize { + 1 } - fn bytes_of(module: &Module, _rows: usize, cols: usize, _size: usize) -> usize { - debug_assert_eq!( - _rows, SCALAR_ZNX_ROWS, - "rows != {} not supported for Scalar", - SCALAR_ZNX_ROWS - ); - debug_assert_eq!( - _size, SCALAR_ZNX_SIZE, - "rows != {} not supported for Scalar", - SCALAR_ZNX_SIZE - ); - module.n() * cols * std::mem::size_of::() + fn n(&self) -> usize { + self.n } -} -impl ZnxLayout for Scalar { - type Scalar = i64; -} + fn size(&self) -> usize { + 1 + } -impl ZnxSliceSize for Scalar { fn sl(&self) -> usize { self.n() } } -impl Scalar { +impl DataView for Scalar { + type D = D; + fn data(&self) -> &Self::D { + &self.data + } +} + +impl DataViewMut for Scalar { + fn data_mut(&mut self) -> &mut Self::D { + &mut self.data + } +} + +impl> ZnxView for Scalar { + type Scalar = i64; +} + +impl + AsRef<[u8]>> Scalar { pub fn fill_ternary_prob(&mut self, col: usize, prob: f64, source: &mut Source) { let choices: [i64; 3] = [-1, 0, 1]; let weights: [f64; 3] = [prob / 2.0, 1.0 - prob, prob / 2.0]; @@ -76,38 +71,89 @@ impl Scalar { self.at_mut(col, 0).shuffle(source); } - pub fn alias_as_vec_znx(&self) -> VecZnx { - VecZnx { - inner: ZnxBase { - n: self.n(), - rows: 1, - cols: 1, - size: 1, - data: Vec::new(), - ptr: self.ptr() as *mut u8, - }, + // pub fn alias_as_vec_znx(&self) -> VecZnx { + // VecZnx { + // inner: ZnxBase { + // n: self.n(), + // rows: 1, + // cols: 1, + // size: 1, + // data: Vec::new(), + // ptr: self.ptr() as *mut u8, + // }, + // } + // } +} + +impl>> Scalar { + pub(crate) fn bytes_of(n: usize, cols: usize) -> usize { + n * cols * size_of::() + } + + pub(crate) fn new(n: usize, cols: usize) -> Self { + let data = alloc_aligned::(Self::bytes_of::(n, cols)); + Self { + data: data.into(), + n, + cols, + } + } + + pub(crate) fn new_from_bytes(n: usize, cols: usize, bytes: impl Into>) -> Self { + let data: Vec = bytes.into(); + assert!(data.len() == Self::bytes_of::(n, cols)); + Self { + data: data.into(), + n, + cols, } } } -pub trait ScalarOps { +pub type ScalarOwned = Scalar>; + +pub trait ScalarAlloc { fn bytes_of_scalar(&self, cols: usize) -> usize; - fn new_scalar(&self, cols: usize) -> Scalar; - fn new_scalar_from_bytes(&self, cols: usize, bytes: Vec) -> Scalar; - fn new_scalar_from_bytes_borrow(&self, cols: usize, bytes: &mut [u8]) -> Scalar; + fn new_scalar(&self, cols: usize) -> 
ScalarOwned; + fn new_scalar_from_bytes(&self, cols: usize, bytes: Vec) -> ScalarOwned; + // fn new_scalar_from_bytes_borrow(&self, cols: usize, bytes: &mut [u8]) -> Scalar; } -impl ScalarOps for Module { +impl ScalarAlloc for Module { fn bytes_of_scalar(&self, cols: usize) -> usize { - Scalar::bytes_of(self, SCALAR_ZNX_ROWS, cols, SCALAR_ZNX_SIZE) + ScalarOwned::bytes_of::(self.n(), cols) } - fn new_scalar(&self, cols: usize) -> Scalar { - Scalar::new(self, SCALAR_ZNX_ROWS, cols, SCALAR_ZNX_SIZE) + fn new_scalar(&self, cols: usize) -> ScalarOwned { + ScalarOwned::new::(self.n(), cols) } - fn new_scalar_from_bytes(&self, cols: usize, bytes: Vec) -> Scalar { - Scalar::from_bytes(self, SCALAR_ZNX_ROWS, cols, SCALAR_ZNX_SIZE, bytes) - } - fn new_scalar_from_bytes_borrow(&self, cols: usize, bytes: &mut [u8]) -> Scalar { - Scalar::from_bytes_borrow(self, SCALAR_ZNX_ROWS, cols, SCALAR_ZNX_SIZE, bytes) + fn new_scalar_from_bytes(&self, cols: usize, bytes: Vec) -> ScalarOwned { + ScalarOwned::new_from_bytes::(self.n(), cols, bytes) } + // fn new_scalar_from_bytes_borrow(&self, cols: usize, bytes: &mut [u8]) -> Scalar { + // Scalar::from_bytes_borrow(self, SCALAR_ZNX_ROWS, cols, SCALAR_ZNX_SIZE, bytes) + // } } + +// impl ZnxAlloc for Scalar { +// type Scalar = i64; + +// fn from_bytes_borrow(module: &Module, _rows: usize, cols: usize, _size: usize, bytes: &mut [u8]) -> Self { +// Self { +// inner: ZnxBase::from_bytes_borrow(module.n(), SCALAR_ZNX_ROWS, cols, SCALAR_ZNX_SIZE, bytes), +// } +// } + +// fn bytes_of(module: &Module, _rows: usize, cols: usize, _size: usize) -> usize { +// debug_assert_eq!( +// _rows, SCALAR_ZNX_ROWS, +// "rows != {} not supported for Scalar", +// SCALAR_ZNX_ROWS +// ); +// debug_assert_eq!( +// _size, SCALAR_ZNX_SIZE, +// "rows != {} not supported for Scalar", +// SCALAR_ZNX_SIZE +// ); +// module.n() * cols * std::mem::size_of::() +// } +// } diff --git a/base2k/src/scalar_znx_dft.rs b/base2k/src/scalar_znx_dft.rs index 6fdb991..09b26d4 100644 --- a/base2k/src/scalar_znx_dft.rs +++ b/base2k/src/scalar_znx_dft.rs @@ -1,67 +1,97 @@ use std::marker::PhantomData; use crate::ffi::svp; -use crate::znx_base::{ZnxAlloc, ZnxBase, ZnxInfos, ZnxLayout, ZnxSliceSize}; -use crate::{Backend, FFT64, GetZnxBase, Module}; +use crate::znx_base::ZnxInfos; +use crate::{Backend, DataView, DataViewMut, FFT64, Module, ZnxView, alloc_aligned}; pub const SCALAR_ZNX_DFT_ROWS: usize = 1; pub const SCALAR_ZNX_DFT_SIZE: usize = 1; -pub struct ScalarZnxDft { - pub inner: ZnxBase, - _marker: PhantomData, +pub struct ScalarZnxDft { + data: D, + n: usize, + cols: usize, + _phantom: PhantomData, } -impl GetZnxBase for ScalarZnxDft { - fn znx(&self) -> &ZnxBase { - &self.inner +impl ZnxInfos for ScalarZnxDft { + fn cols(&self) -> usize { + self.cols } - fn znx_mut(&mut self) -> &mut ZnxBase { - &mut self.inner + fn rows(&self) -> usize { + 1 + } + + fn n(&self) -> usize { + self.n + } + + fn size(&self) -> usize { + 1 + } + + fn sl(&self) -> usize { + self.n() } } -impl ZnxInfos for ScalarZnxDft {} - -impl ZnxAlloc for ScalarZnxDft { - type Scalar = u8; - - fn from_bytes_borrow(module: &Module, _rows: usize, cols: usize, _size: usize, bytes: &mut [u8]) -> Self { - debug_assert_eq!(bytes.len(), Self::bytes_of(module, _rows, cols, _size)); - Self { - inner: ZnxBase::from_bytes_borrow( - module.n(), - SCALAR_ZNX_DFT_ROWS, - cols, - SCALAR_ZNX_DFT_SIZE, - bytes, - ), - _marker: PhantomData, - } - } - - fn bytes_of(module: &Module, _rows: usize, cols: usize, _size: usize) -> usize { - debug_assert_eq!( - 
_rows, SCALAR_ZNX_DFT_ROWS, - "rows != {} not supported for ScalarZnxDft", - SCALAR_ZNX_DFT_ROWS - ); - debug_assert_eq!( - _size, SCALAR_ZNX_DFT_SIZE, - "rows != {} not supported for ScalarZnxDft", - SCALAR_ZNX_DFT_SIZE - ); - unsafe { svp::bytes_of_svp_ppol(module.ptr) as usize * cols } +impl DataView for ScalarZnxDft { + type D = D; + fn data(&self) -> &Self::D { + &self.data } } -impl ZnxLayout for ScalarZnxDft { +impl DataViewMut for ScalarZnxDft { + fn data_mut(&mut self) -> &mut Self::D { + &mut self.data + } +} + +impl> ZnxView for ScalarZnxDft { type Scalar = f64; } -impl ZnxSliceSize for ScalarZnxDft { - fn sl(&self) -> usize { - self.n() * self.cols() +impl>, B: Backend> ScalarZnxDft { + pub(crate) fn bytes_of(module: &Module, cols: usize) -> usize { + unsafe { svp::bytes_of_svp_ppol(module.ptr) as usize * cols } } + + pub(crate) fn new(module: &Module, cols: usize) -> Self { + let data = alloc_aligned::(Self::bytes_of(module, cols)); + Self { + data: data.into(), + n: module.n(), + cols, + _phantom: PhantomData, + } + } + + pub(crate) fn new_from_bytes(module: &Module, cols: usize, bytes: impl Into>) -> Self { + let data: Vec = bytes.into(); + assert!(data.len() == Self::bytes_of(module, cols)); + Self { + data: data.into(), + n: module.n(), + cols, + _phantom: PhantomData, + } + } + + // fn from_bytes_borrow(module: &Module, _rows: usize, cols: usize, _size: usize, bytes: &mut [u8]) -> Self { + // debug_assert_eq!(bytes.len(), Self::bytes_of(module, _rows, cols, _size)); + // Self { + // inner: ZnxBase::from_bytes_borrow( + // module.n(), + // SCALAR_ZNX_DFT_ROWS, + // cols, + // SCALAR_ZNX_DFT_SIZE, + // bytes, + // ), + // _phantom: PhantomData, + // } + // } } + +pub type ScalarZnxDftOwned = ScalarZnxDft, B>; diff --git a/base2k/src/scalar_znx_dft_ops.rs b/base2k/src/scalar_znx_dft_ops.rs index 4fbe99d..fc56e4e 100644 --- a/base2k/src/scalar_znx_dft_ops.rs +++ b/base2k/src/scalar_znx_dft_ops.rs @@ -1,35 +1,52 @@ use crate::ffi::svp::{self, svp_ppol_t}; use crate::ffi::vec_znx_dft::vec_znx_dft_t; -use crate::znx_base::{ZnxAlloc, ZnxInfos, ZnxLayout, ZnxSliceSize}; -use crate::{Backend, FFT64, Module, SCALAR_ZNX_DFT_ROWS, SCALAR_ZNX_DFT_SIZE, Scalar, ScalarZnxDft, VecZnx, VecZnxDft}; +use crate::znx_base::{ZnxInfos, ZnxView, ZnxViewMut}; +use crate::{Backend, FFT64, Module, Scalar, ScalarZnxDft, ScalarZnxDftOwned, VecZnx, VecZnxDft}; -pub trait ScalarZnxDftOps { - fn new_scalar_znx_dft(&self, cols: usize) -> ScalarZnxDft; +pub trait ScalarZnxDftAlloc { + fn new_scalar_znx_dft(&self, cols: usize) -> ScalarZnxDftOwned; fn bytes_of_scalar_znx_dft(&self, cols: usize) -> usize; - fn new_scalar_znx_dft_from_bytes(&self, cols: usize, bytes: Vec) -> ScalarZnxDft; - fn new_scalar_znx_dft_from_bytes_borrow(&self, cols: usize, bytes: &mut [u8]) -> ScalarZnxDft; - fn svp_prepare(&self, res: &mut ScalarZnxDft, res_col: usize, a: &Scalar, a_col: usize); - fn svp_apply_dft(&self, res: &mut VecZnxDft, res_col: usize, a: &ScalarZnxDft, a_col: usize, b: &VecZnx, b_col: usize); + fn new_scalar_znx_dft_from_bytes(&self, cols: usize, bytes: Vec) -> ScalarZnxDftOwned; + // fn new_scalar_znx_dft_from_bytes_borrow(&self, cols: usize, bytes: &mut [u8]) -> ScalarZnxDft; } -impl ScalarZnxDftOps for Module { - fn new_scalar_znx_dft(&self, cols: usize) -> ScalarZnxDft { - ScalarZnxDft::::new(&self, SCALAR_ZNX_DFT_ROWS, cols, SCALAR_ZNX_DFT_SIZE) +pub trait ScalarZnxDftOps { + fn svp_prepare(&self, res: &mut ScalarZnxDft, res_col: usize, a: &Scalar, a_col: usize); + fn svp_apply_dft( + &self, + res: &mut 
VecZnxDft, + res_col: usize, + a: &ScalarZnxDft, + a_col: usize, + b: &VecZnx, + b_col: usize, + ); +} + +impl ScalarZnxDftAlloc for Module { + fn new_scalar_znx_dft(&self, cols: usize) -> ScalarZnxDftOwned { + ScalarZnxDftOwned::new(self, cols) } fn bytes_of_scalar_znx_dft(&self, cols: usize) -> usize { - ScalarZnxDft::::bytes_of(self, SCALAR_ZNX_DFT_ROWS, cols, SCALAR_ZNX_DFT_SIZE) + ScalarZnxDftOwned::bytes_of(self, cols) } - fn new_scalar_znx_dft_from_bytes(&self, cols: usize, bytes: Vec) -> ScalarZnxDft { - ScalarZnxDft::from_bytes(self, SCALAR_ZNX_DFT_ROWS, cols, SCALAR_ZNX_DFT_SIZE, bytes) + fn new_scalar_znx_dft_from_bytes(&self, cols: usize, bytes: Vec) -> ScalarZnxDftOwned { + ScalarZnxDftOwned::new_from_bytes(self, cols, bytes) } - fn new_scalar_znx_dft_from_bytes_borrow(&self, cols: usize, bytes: &mut [u8]) -> ScalarZnxDft { - ScalarZnxDft::from_bytes_borrow(self, SCALAR_ZNX_DFT_ROWS, cols, SCALAR_ZNX_DFT_SIZE, bytes) - } + // fn new_scalar_znx_dft_from_bytes_borrow(&self, cols: usize, bytes: &mut [u8]) -> ScalarZnxDft { + // ScalarZnxDft::from_bytes_borrow(self, SCALAR_ZNX_DFT_ROWS, cols, SCALAR_ZNX_DFT_SIZE, bytes) + // } +} - fn svp_prepare(&self, res: &mut ScalarZnxDft, res_col: usize, a: &Scalar, a_col: usize) { +impl ScalarZnxDftOps for Module +where + DataMut: AsMut<[u8]> + AsRef<[u8]>, + Data: AsRef<[u8]>, +{ + fn svp_prepare(&self, res: &mut ScalarZnxDft, res_col: usize, a: &Scalar, a_col: usize) { unsafe { svp::svp_prepare( self.ptr, @@ -41,11 +58,11 @@ impl ScalarZnxDftOps for Module { fn svp_apply_dft( &self, - res: &mut VecZnxDft, + res: &mut VecZnxDft, res_col: usize, - a: &ScalarZnxDft, + a: &ScalarZnxDft, a_col: usize, - b: &VecZnx, + b: &VecZnx, b_col: usize, ) { unsafe { diff --git a/base2k/src/stats.rs b/base2k/src/stats.rs index a1946ab..c6d16b4 100644 --- a/base2k/src/stats.rs +++ b/base2k/src/stats.rs @@ -9,7 +9,7 @@ pub trait Stats { fn std(&self, col_i: usize, log_base2k: usize) -> f64; } -impl Stats for VecZnx { +impl + AsRef<[u8]>> Stats for VecZnx { fn std(&self, col_i: usize, log_base2k: usize) -> f64 { let prec: u32 = (self.size() * log_base2k) as u32; let mut data: Vec = (0..self.n()).map(|_| Float::with_val(prec, 0)).collect(); diff --git a/base2k/src/vec_znx.rs b/base2k/src/vec_znx.rs index b76f93d..3321f8e 100644 --- a/base2k/src/vec_znx.rs +++ b/base2k/src/vec_znx.rs @@ -1,13 +1,10 @@ -use crate::Backend; use crate::DataView; use crate::DataViewMut; -use crate::Module; -use crate::ZnxView; use crate::alloc_aligned; use crate::assert_alignement; use crate::cast_mut; use crate::ffi::znx; -use crate::znx_base::{GetZnxBase, ZnxAlloc, ZnxBase, ZnxInfos, ZnxRsh, ZnxZero, switch_degree}; +use crate::znx_base::{ZnxInfos, ZnxView, ZnxViewMut, switch_degree}; use std::{cmp::min, fmt}; // pub const VEC_ZNX_ROWS: usize = 1; @@ -59,7 +56,7 @@ impl DataView for VecZnx { } impl DataViewMut for VecZnx { - fn data_mut(&self) -> &mut Self::D { + fn data_mut(&mut self) -> &mut Self::D { &mut self.data } } @@ -84,7 +81,7 @@ impl + AsRef<[u8]>> VecZnx { return; } - self.inner.size -= k / log_base2k; + self.size -= k / log_base2k; let k_rem: usize = k % log_base2k; @@ -97,7 +94,7 @@ impl + AsRef<[u8]>> VecZnx { } /// Switches degree of from `a.n()` to `self.n()` into `self` - pub fn switch_degree>(&mut self, col: usize, a: &Data, col_a: usize) { + pub fn switch_degree>(&mut self, col: usize, a: &VecZnx, col_a: usize) { switch_degree(self, col_a, a, col) } @@ -161,7 +158,7 @@ fn normalize_tmp_bytes(n: usize) -> usize { n * std::mem::size_of::() } -fn 
normalize>(log_base2k: usize, a: &mut VecZnx, a_col: usize, tmp_bytes: &mut [u8]) { +fn normalize + AsRef<[u8]>>(log_base2k: usize, a: &mut VecZnx, a_col: usize, tmp_bytes: &mut [u8]) { let n: usize = a.n(); debug_assert!( diff --git a/base2k/src/vec_znx_big.rs b/base2k/src/vec_znx_big.rs index 682493a..72b15d7 100644 --- a/base2k/src/vec_znx_big.rs +++ b/base2k/src/vec_znx_big.rs @@ -1,11 +1,11 @@ use crate::ffi::vec_znx_big; -use crate::znx_base::{GetZnxBase, ZnxAlloc, ZnxBase, ZnxInfos, ZnxView}; +use crate::znx_base::{ZnxInfos, ZnxView}; use crate::{Backend, DataView, DataViewMut, FFT64, Module, alloc_aligned}; use std::marker::PhantomData; const VEC_ZNX_BIG_ROWS: usize = 1; -/// VecZnxBig is Backend dependent, denoted with backend generic `B` +/// VecZnxBig is `Backend` dependent, denoted with backend generic `B` pub struct VecZnxBig { data: D, n: usize, @@ -44,7 +44,7 @@ impl DataView for VecZnxBig { } impl DataViewMut for VecZnxBig { - fn data_mut(&self) -> &mut Self::D { + fn data_mut(&mut self) -> &mut Self::D { &mut self.data } } diff --git a/base2k/src/vec_znx_big_ops.rs b/base2k/src/vec_znx_big_ops.rs index 5353c32..bb46802 100644 --- a/base2k/src/vec_znx_big_ops.rs +++ b/base2k/src/vec_znx_big_ops.rs @@ -1,6 +1,6 @@ use crate::ffi::vec_znx; -use crate::znx_base::{ZnxAlloc, ZnxInfos, ZnxView, ZnxViewMut}; -use crate::{Backend, DataView, FFT64, Module, VecZnx, VecZnxBig, VecZnxBigOwned, VecZnxOps, assert_alignement}; +use crate::znx_base::{ZnxInfos, ZnxView, ZnxViewMut}; +use crate::{Backend, DataView, FFT64, Module, ScratchSpace, VecZnx, VecZnxBig, VecZnxBigOwned, VecZnxOps, assert_alignement}; pub trait VecZnxBigAlloc { /// Allocates a vector Z[X]/(X^N+1) that stores not normalized values. @@ -79,13 +79,13 @@ pub trait VecZnxBigOps { b_col: usize, ); - /// Subtracts `a` to `b` and stores the result on `b`. + /// Subtracts `a` from `b` and stores the result on `b`. fn vec_znx_big_sub_ab_inplace(&self, res: &mut VecZnxBig, res_col: usize, a: &VecZnxBig, a_col: usize); - /// Subtracts `b` to `a` and stores the result on `b`. + /// Subtracts `b` from `a` and stores the result on `b`. fn vec_znx_big_sub_ba_inplace(&self, res: &mut VecZnxBig, res_col: usize, a: &VecZnxBig, a_col: usize); - /// Subtracts `b` to `a` and stores the result on `c`. + /// Subtracts `b` from `a` and stores the result on `c`. fn vec_znx_big_sub_small_a( &self, res: &mut VecZnxBig, @@ -96,10 +96,10 @@ pub trait VecZnxBigOps { b_col: usize, ); - /// Subtracts `a` to `b` and stores the result on `b`. + /// Subtracts `a` from `res` and stores the result on `res`. fn vec_znx_big_sub_small_a_inplace(&self, res: &mut VecZnxBig, res_col: usize, a: &VecZnx, a_col: usize); - /// Subtracts `b` to `a` and stores the result on `c`. + /// Subtracts `b` from `a` and stores the result on `c`. fn vec_znx_big_sub_small_b( &self, res: &mut VecZnxBig, @@ -110,7 +110,7 @@ pub trait VecZnxBigOps { b_col: usize, ); - /// Subtracts `b` to `a` and stores the result on `b`. + /// Subtracts `res` from `a` and stores the result on `res`. fn vec_znx_big_sub_small_b_inplace(&self, res: &mut VecZnxBig, res_col: usize, a: &VecZnx, a_col: usize); /// Returns the minimum number of bytes to apply [VecZnxBigOps::vec_znx_big_normalize]. @@ -129,7 +129,7 @@ pub trait VecZnxBigOps { res_col: usize, a: &VecZnxBig, a_col: usize, - tmp_bytes: &mut [u8], + scratch: &mut ScratchSpace, ); /// Applies the automorphism X^i -> X^ik on `a` and stores the result on `b`. 
@@ -160,7 +160,7 @@ impl VecZnxBigAlloc for Module { // } fn bytes_of_vec_znx_big(&self, cols: usize, size: usize) -> usize { - VecZnxBig::bytes_of(self, cols, size) + VecZnxBigOwned::bytes_of(self, cols, size) } } @@ -208,8 +208,24 @@ where a: &VecZnxBig, a_col: usize, ) { + #[cfg(debug_assertions)] + { + assert_eq!(a.n(), self.n()); + assert_eq!(res.n(), self.n()); + } unsafe { - Self::vec_znx_big_add(self, res, res_col, a, a_col, res, res_col); + vec_znx::vec_znx_add( + self.ptr, + res.at_mut_ptr(res_col, 0), + res.size() as u64, + res.sl() as u64, + a.at_ptr(a_col, 0), + a.size() as u64, + a.sl() as u64, + res.at_ptr(res_col, 0), + res.size() as u64, + res.sl() as u64, + ) } } @@ -245,7 +261,6 @@ where } } - //(Jay)TODO: check whether definitions sub_ab, sub_ba make sense to you fn vec_znx_big_sub_ab_inplace( &self, res: &mut VecZnxBig, @@ -253,8 +268,24 @@ where a: &VecZnxBig, a_col: usize, ) { + #[cfg(debug_assertions)] + { + assert_eq!(a.n(), self.n()); + assert_eq!(res.n(), self.n()); + } unsafe { - Self::vec_znx_big_sub(self, res, res_col, a, a_col, res, res_col); + vec_znx::vec_znx_sub( + self.ptr, + res.at_mut_ptr(res_col, 0), + res.size() as u64, + res.sl() as u64, + res.at_ptr(res_col, 0), + res.size() as u64, + res.sl() as u64, + a.at_ptr(a_col, 0), + a.size() as u64, + a.sl() as u64, + ) } } @@ -265,8 +296,24 @@ where a: &VecZnxBig, a_col: usize, ) { + #[cfg(debug_assertions)] + { + assert_eq!(a.n(), self.n()); + assert_eq!(res.n(), self.n()); + } unsafe { - Self::vec_znx_big_sub(self, res, res_col, res, res_col, a, a_col); + vec_znx::vec_znx_sub( + self.ptr, + res.at_mut_ptr(res_col, 0), + res.size() as u64, + res.sl() as u64, + a.at_ptr(a_col, 0), + a.size() as u64, + a.sl() as u64, + res.at_ptr(res_col, 0), + res.size() as u64, + res.sl() as u64, + ) } } @@ -309,8 +356,24 @@ where a: &VecZnx, a_col: usize, ) { + #[cfg(debug_assertions)] + { + assert_eq!(a.n(), self.n()); + assert_eq!(res.n(), self.n()); + } unsafe { - Self::vec_znx_big_sub_small_b(self, res, res_col, res, res_col, a, a_col); + vec_znx::vec_znx_sub( + self.ptr, + res.at_mut_ptr(res_col, 0), + res.size() as u64, + res.sl() as u64, + a.at_ptr(a_col, 0), + a.size() as u64, + a.sl() as u64, + res.at_ptr(res_col, 0), + res.size() as u64, + res.sl() as u64, + ) } } @@ -353,8 +416,24 @@ where a: &VecZnx, a_col: usize, ) { + #[cfg(debug_assertions)] + { + assert_eq!(a.n(), self.n()); + assert_eq!(res.n(), self.n()); + } unsafe { - Self::vec_znx_big_sub_small_a(self, res, res_col, a, a_col, res, res_col); + vec_znx::vec_znx_sub( + self.ptr, + res.at_mut_ptr(res_col, 0), + res.size() as u64, + res.sl() as u64, + res.at_ptr(res_col, 0), + res.size() as u64, + res.sl() as u64, + a.at_ptr(a_col, 0), + a.size() as u64, + a.sl() as u64, + ) } } @@ -391,11 +470,29 @@ where } fn vec_znx_big_add_small_inplace(&self, res: &mut VecZnxBig, res_col: usize, a: &VecZnx, a_col: usize) { - Self::vec_znx_big_add_small(self, res, res_col, res, res_col, a, a_col); + #[cfg(debug_assertions)] + { + assert_eq!(a.n(), self.n()); + assert_eq!(res.n(), self.n()); + } + unsafe { + vec_znx::vec_znx_add( + self.ptr, + res.at_mut_ptr(res_col, 0), + res.size() as u64, + res.sl() as u64, + res.at_ptr(res_col, 0), + res.size() as u64, + res.sl() as u64, + a.at_ptr(a_col, 0), + a.size() as u64, + a.sl() as u64, + ) + } } fn vec_znx_big_normalize_tmp_bytes(&self) -> usize { - Self::vec_znx_normalize_tmp_bytes(self) + >::vec_znx_normalize_tmp_bytes(self) } fn vec_znx_big_normalize( @@ -405,14 +502,16 @@ where res_col: usize, a: &VecZnxBig, a_col: 
usize, - tmp_bytes: &mut [u8], + scratch: &mut ScratchSpace, ) { #[cfg(debug_assertions)] { assert_eq!(a.n(), self.n()); assert_eq!(res.n(), self.n()); - assert!(tmp_bytes.len() >= Self::vec_znx_normalize_tmp_bytes(&self)); - assert_alignement(tmp_bytes.as_ptr()); + //(Jay)Note: This is calling VezZnxOps::vec_znx_normalize_tmp_bytes and not VecZnxBigOps::vec_znx_big_normalize_tmp_bytes. + // In the FFT backend the tmp sizes are same but will be different in the NTT backend + // assert!(tmp_bytes.len() >= >::vec_znx_normalize_tmp_bytes(&self)); + // assert_alignement(tmp_bytes.as_ptr()); } unsafe { vec_znx::vec_znx_normalize_base2k( @@ -424,7 +523,7 @@ where a.at_ptr(a_col, 0), a.size() as u64, a.sl() as u64, - tmp_bytes.as_mut_ptr(), + scratch.vec_znx_big_normalize_tmp_bytes(self).as_mut_ptr(), ); } } @@ -457,8 +556,21 @@ where } fn vec_znx_big_automorphism_inplace(&self, k: i64, a: &mut VecZnxBig, a_col: usize) { + #[cfg(debug_assertions)] + { + assert_eq!(a.n(), self.n()); + } unsafe { - Self::vec_znx_big_automorphism(self, k, a, a_col, a, a_col); + vec_znx::vec_znx_automorphism( + self.ptr, + k, + a.at_mut_ptr(a_col, 0), + a.size() as u64, + a.sl() as u64, + a.at_ptr(a_col, 0), + a.size() as u64, + a.sl() as u64, + ) } } } diff --git a/base2k/src/vec_znx_dft.rs b/base2k/src/vec_znx_dft.rs index c192486..74b559c 100644 --- a/base2k/src/vec_znx_dft.rs +++ b/base2k/src/vec_znx_dft.rs @@ -1,11 +1,12 @@ use std::marker::PhantomData; use crate::ffi::vec_znx_dft; -use crate::znx_base::{ZnxAlloc, ZnxInfos}; +use crate::znx_base::ZnxInfos; use crate::{Backend, DataView, DataViewMut, FFT64, Module, ZnxView, alloc_aligned}; const VEC_ZNX_DFT_ROWS: usize = 1; +// VecZnxDft is `Backend` dependent denoted with generic `B` pub struct VecZnxDft { data: D, n: usize, @@ -44,7 +45,7 @@ impl DataView for VecZnxDft { } impl DataViewMut for VecZnxDft { - fn data_mut(&self) -> &mut Self::D { + fn data_mut(&mut self) -> &mut Self::D { &mut self.data } } @@ -84,6 +85,18 @@ impl>, B: Backend> VecZnxDft { pub type VecZnxDftOwned = VecZnxDft, B>; +impl<'a, D: ?Sized, B> VecZnxDft<&'a mut D, B> { + pub(crate) fn from_mut_slice(data: &'a mut D, n: usize, cols: usize, size: usize) -> Self { + Self { + data, + n, + cols, + size, + _phantom: PhantomData, + } + } +} + // impl ZnxAlloc for VecZnxDft { // type Scalar = u8; diff --git a/base2k/src/vec_znx_dft_ops.rs b/base2k/src/vec_znx_dft_ops.rs index cf2090b..2c1cc97 100644 --- a/base2k/src/vec_znx_dft_ops.rs +++ b/base2k/src/vec_znx_dft_ops.rs @@ -1,7 +1,5 @@ use crate::VecZnxDftOwned; -use crate::ffi::vec_znx_big; -use crate::ffi::vec_znx_dft; -use crate::znx_base::ZnxAlloc; +use crate::ffi::{vec_znx_big, vec_znx_dft}; use crate::znx_base::ZnxInfos; use crate::{FFT64, Module, VecZnx, VecZnxBig, VecZnxDft, ZnxView, ZnxViewMut, ZnxZero, assert_alignement}; use std::cmp::min; @@ -82,7 +80,7 @@ impl VecZnxDftAlloc for Module { // } fn bytes_of_vec_znx_dft(&self, cols: usize, size: usize) -> usize { - VecZnxDft::bytes_of(&self, cols, size) + VecZnxDftOwned::bytes_of(&self, cols, size) } } @@ -156,10 +154,10 @@ where #[cfg(debug_assertions)] { assert!( - tmp_bytes.len() >= Self::vec_znx_idft_tmp_bytes(self), + tmp_bytes.len() >= >::vec_znx_idft_tmp_bytes(self), "invalid tmp_bytes: tmp_bytes.len()={} < self.vec_znx_idft_tmp_bytes()={}", tmp_bytes.len(), - Self::vec_znx_idft_tmp_bytes(self) + >::vec_znx_idft_tmp_bytes(self) ); assert_alignement(tmp_bytes.as_ptr()) } diff --git a/base2k/src/vec_znx_ops.rs b/base2k/src/vec_znx_ops.rs index 339bc12..6951651 100644 --- 
a/base2k/src/vec_znx_ops.rs
+++ b/base2k/src/vec_znx_ops.rs
@@ -86,10 +86,14 @@ pub trait VecZnxOps {
     );

     /// Subtracts the selected column of `a` from the selected column of `res` inplace.
+    ///
+    /// res[res_col] -= a[a_col]
     fn vec_znx_sub_ab_inplace(&self, res: &mut VecZnx, res_col: usize, a: &VecZnx, a_col: usize);

-    // /// Subtracts the selected column of `a` from the selected column of `res` and negates the selected column of `res`.
-    // fn vec_znx_sub_ba_inplace(&self, res: &mut VecZnx, res_col: usize, a: &VecZnx, a_col: usize);
+    /// Subtracts the selected column of `res` from the selected column of `a`, mutating `res` in place.
+    ///
+    /// res[res_col] = a[a_col] - res[res_col]
+    fn vec_znx_sub_ba_inplace(&self, res: &mut VecZnx, res_col: usize, a: &VecZnx, a_col: usize);

     // Negates the selected column of `a` and stores the result in `res_col` of `res`.
     fn vec_znx_negate(&self, res: &mut VecZnx, res_col: usize, a: &VecZnx, a_col: usize);
@@ -136,15 +140,15 @@ pub trait VecZnxOps {

 impl VecZnxAlloc for Module {
     //(Jay)TODO: One must define the Scalar generic param here.
     fn new_vec_znx(&self, cols: usize, size: usize) -> VecZnxOwned {
-        VecZnxOwned::new(self.n(), cols, size)
+        VecZnxOwned::new::(self.n(), cols, size)
     }

     fn bytes_of_vec_znx(&self, cols: usize, size: usize) -> usize {
-        VecZnxOwned::bytes_of(self.n(), cols, size)
+        VecZnxOwned::bytes_of::(self.n(), cols, size)
     }

     fn new_vec_znx_from_bytes(&self, cols: usize, size: usize, bytes: Vec) -> VecZnxOwned {
-        VecZnxOwned::new_from_bytes(self.n(), cols, size, bytes)
+        VecZnxOwned::new_from_bytes::(self.n(), cols, size, bytes)
     }
 }
@@ -170,7 +174,7 @@ where
         {
             assert_eq!(a.n(), self.n());
             assert_eq!(res.n(), self.n());
-            assert!(tmp_bytes.len() >= Self::vec_znx_normalize_tmp_bytes(&self));
+            assert!(tmp_bytes.len() >= >::vec_znx_normalize_tmp_bytes(&self));
             assert_alignement(tmp_bytes.as_ptr());
         }
         unsafe {
@@ -190,16 +194,8 @@ where
     fn vec_znx_normalize_inplace(&self, log_base2k: usize, a: &mut VecZnx, a_col: usize, tmp_bytes: &mut [u8]) {
         unsafe {
-            let a_ptr: *mut VecZnx = a as *mut VecZnx;
-            Self::vec_znx_normalize(
-                self,
-                log_base2k,
-                &mut *a_ptr,
-                a_col,
-                &*a_ptr,
-                a_col,
-                tmp_bytes,
-            );
+            let a_ptr: *const VecZnx<_> = a;
+            Self::vec_znx_normalize(self, log_base2k, a, a_col, &*a_ptr, a_col, tmp_bytes);
         }
     }
@@ -236,8 +232,24 @@ where
     }

     fn vec_znx_add_inplace(&self, res: &mut VecZnx, res_col: usize, a: &VecZnx, a_col: usize) {
+        #[cfg(debug_assertions)]
+        {
+            assert_eq!(a.n(), self.n());
+            assert_eq!(res.n(), self.n());
+        }
         unsafe {
-            Self::vec_znx_add(&self, res, res_col, a, a_col, res, res_col);
+            vec_znx::vec_znx_add(
+                self.ptr,
+                res.at_mut_ptr(res_col, 0),
+                res.size() as u64,
+                res.sl() as u64,
+                a.at_ptr(a_col, 0),
+                a.size() as u64,
+                a.sl() as u64,
+                res.at_ptr(res_col, 0),
+                res.size() as u64,
+                res.sl() as u64,
+            )
         }
     }
@@ -274,18 +286,48 @@ where
     }

     fn vec_znx_sub_ab_inplace(&self, res: &mut VecZnx, res_col: usize, a: &VecZnx, a_col: usize) {
+        #[cfg(debug_assertions)]
+        {
+            assert_eq!(a.n(), self.n());
+            assert_eq!(res.n(), self.n());
+        }
         unsafe {
-            let res_ptr: *mut VecZnx = res as *mut VecZnx;
-            Self::vec_znx_sub(self, res, res_col, a, a_col, res, res_col);
+            vec_znx::vec_znx_sub(
+                self.ptr,
+                res.at_mut_ptr(res_col, 0),
+                res.size() as u64,
+                res.sl() as u64,
+                res.at_ptr(res_col, 0),
+                res.size() as u64,
+                res.sl() as u64,
+                a.at_ptr(a_col, 0),
+                a.size() as u64,
+                a.sl() as u64,
+            )
         }
     }

-    // fn vec_znx_sub_ba_inplace(&self, res: &mut VecZnx, res_col: usize, a: &VecZnx, a_col: usize) {
-    //     unsafe {
-    //         let res_ptr: *mut VecZnx = res as *mut VecZnx;
-    //         Self::vec_znx_sub(self, &mut *res_ptr, res_col, &*res_ptr, res_col, a, a_col);
-    //     }
-    // }
+    fn vec_znx_sub_ba_inplace(&self, res: &mut VecZnx, res_col: usize, a: &VecZnx, a_col: usize) {
+        #[cfg(debug_assertions)]
+        {
+            assert_eq!(a.n(), self.n());
+            assert_eq!(res.n(), self.n());
+        }
+        unsafe {
+            vec_znx::vec_znx_sub(
+                self.ptr,
+                res.at_mut_ptr(res_col, 0),
+                res.size() as u64,
+                res.sl() as u64,
+                a.at_ptr(a_col, 0),
+                a.size() as u64,
+                a.sl() as u64,
+                res.at_ptr(res_col, 0),
+                res.size() as u64,
+                res.sl() as u64,
+            )
+        }
+    }

     fn vec_znx_negate(&self, res: &mut VecZnx, res_col: usize, a: &VecZnx, a_col: usize) {
         #[cfg(debug_assertions)]
@@ -308,7 +350,8 @@ where
     fn vec_znx_negate_inplace(&self, a: &mut VecZnx, a_col: usize) {
         unsafe {
-            Self::vec_znx_negate(self, a, a_col, a, a_col);
+            let a_ref: *const VecZnx<_> = a;
+            Self::vec_znx_negate(self, a, a_col, a_ref.as_ref().unwrap(), a_col);
         }
     }
@@ -333,8 +376,21 @@ where
     }

     fn vec_znx_rotate_inplace(&self, k: i64, a: &mut VecZnx, a_col: usize) {
+        #[cfg(debug_assertions)]
+        {
+            assert_eq!(a.n(), self.n());
+        }
         unsafe {
-            Self::vec_znx_rotate(self, k, a, a_col, a, a_col);
+            vec_znx::vec_znx_rotate(
+                self.ptr,
+                k,
+                a.at_mut_ptr(a_col, 0),
+                a.size() as u64,
+                a.sl() as u64,
+                a.at_ptr(a_col, 0),
+                a.size() as u64,
+                a.sl() as u64,
+            )
         }
     }
@@ -359,8 +415,21 @@ where
     }

     fn vec_znx_automorphism_inplace(&self, k: i64, a: &mut VecZnx, a_col: usize) {
+        #[cfg(debug_assertions)]
+        {
+            assert_eq!(a.n(), self.n());
+        }
         unsafe {
-            Self::vec_znx_automorphism(self, k, a, a_col, a, a_col);
+            vec_znx::vec_znx_automorphism(
+                self.ptr,
+                k,
+                a.at_mut_ptr(a_col, 0),
+                a.size() as u64,
+                a.sl() as u64,
+                a.at_ptr(a_col, 0),
+                a.size() as u64,
+                a.sl() as u64,
+            )
         }
     }
@@ -392,7 +461,7 @@ where
                 self.vec_znx_rotate(-1, buf, 0, a, a_col);
             } else {
                 switch_degree(bi, res_col, buf, a_col);
-                self.vec_znx_rotate_inplace(-1, buf, a_col);
+                >::vec_znx_rotate_inplace(self, -1, buf, a_col);
             }
         })
     }
@@ -414,9 +483,9 @@ where
         a.iter().enumerate().for_each(|(_, ai)| {
             switch_degree(res, res_col, ai, a_col);
-            self.vec_znx_rotate_inplace(-1, res, res_col);
+            >::vec_znx_rotate_inplace(self, -1, res, res_col);
         });
-        self.vec_znx_rotate_inplace(a.len() as i64, res, res_col);
+        >::vec_znx_rotate_inplace(self, a.len() as i64, res, res_col);
     }
 }
diff --git a/base2k/src/znx_base.rs b/base2k/src/znx_base.rs
index bf941d4..a7361ad 100644
--- a/base2k/src/znx_base.rs
+++ b/base2k/src/znx_base.rs
@@ -85,26 +85,26 @@ pub trait ZnxInfos {
 // pub trait ZnxSliceSize {}

 //(Jay) TODO: Remove ZnxAlloc
-pub trait ZnxAlloc
-where
-    Self: Sized + ZnxInfos,
-{
-    type Scalar;
-    fn new(module: &Module, rows: usize, cols: usize, size: usize) -> Self {
-        let bytes: Vec = alloc_aligned::(Self::bytes_of(module, rows, cols, size));
-        Self::from_bytes(module, rows, cols, size, bytes)
-    }
+// pub trait ZnxAlloc
+// where
+// Self: Sized + ZnxInfos,
+// {
+// type Scalar;
+// fn new(module: &Module, rows: usize, cols: usize, size: usize) -> Self {
+// let bytes: Vec = alloc_aligned::(Self::bytes_of(module, rows, cols, size));
+// Self::from_bytes(module, rows, cols, size, bytes)
+// }

-    fn from_bytes(module: &Module, rows: usize, cols: usize, size: usize, mut bytes: Vec) -> Self {
-        let mut res: Self = Self::from_bytes_borrow(module, rows, cols, size, &mut bytes);
-        res.znx_mut().data = bytes;
-        res
-    }
+// fn from_bytes(module: &Module, rows: usize, cols: usize, size: usize, mut bytes: Vec) -> Self {
+// let mut res: Self = Self::from_bytes_borrow(module, rows, cols, size, &mut bytes);
+// res.znx_mut().data = bytes;
+// res
+// }

-    fn from_bytes_borrow(module: &Module, rows: usize, cols: usize, size: usize, bytes: &mut [u8]) -> Self;
+// fn from_bytes_borrow(module: &Module, rows: usize, cols: usize, size: usize, bytes: &mut [u8]) -> Self;

-    fn bytes_of(module: &Module, rows: usize, cols: usize, size: usize) -> usize;
-}
+// fn bytes_of(module: &Module, rows: usize, cols: usize, size: usize) -> usize;
+// }

 pub trait DataView {
     type D;
 }
@@ -112,11 +112,11 @@
 pub trait DataViewMut: DataView {
-    fn data_mut(&self) -> &mut Self::D;
+    fn data_mut(&mut self) -> &mut Self::D;
 }

 pub trait ZnxView: ZnxInfos + DataView> {
-    type Scalar;
+    type Scalar: Copy;

     /// Returns a non-mutable pointer to the underlying coefficients array.
     fn as_ptr(&self) -> *const Self::Scalar {
@@ -177,11 +177,9 @@ pub trait ZnxViewMut: ZnxView + DataViewMut> {
 impl ZnxViewMut for T where T: ZnxView + DataViewMut> {}

 use std::convert::TryFrom;
-use std::num::TryFromIntError;
 use std::ops::{Add, AddAssign, Div, Mul, Neg, Shl, Shr, Sub};

-pub trait IntegerType:
+pub trait Num:
     Copy
-    + std::fmt::Debug
     + Default
     + PartialEq
     + PartialOrd
@@ -190,22 +188,23 @@
     + Mul
     + Div
     + Neg
-    + Shr
-    + Shl
     + AddAssign
-    + TryFrom
 {
     const BITS: u32;
 }

-impl IntegerType for i64 {
+impl Num for i64 {
     const BITS: u32 = 64;
 }

-impl IntegerType for i128 {
+impl Num for i128 {
     const BITS: u32 = 128;
 }

+impl Num for f64 {
+    const BITS: u32 = 64;
+}
+
 pub trait ZnxZero: ZnxViewMut
 where
     Self: Sized,
@@ -231,79 +230,16 @@
     }
 }

-pub trait ZnxRsh: ZnxZero {
-    fn rsh(&mut self, k: usize, log_base2k: usize, col: usize, carry: &mut [u8]) {
-        rsh(k, log_base2k, self, col, carry)
-    }
-}
-
 // Blanket implementations
 impl ZnxZero for T where T: ZnxViewMut {}
-impl ZnxRsh for T where T: ZnxZero {}
+// impl ZnxRsh for T where T: ZnxZero {}

-pub fn rsh(k: usize, log_base2k: usize, a: &mut V, a_col: usize, tmp_bytes: &mut [u8])
-where
-    V::Scalar: IntegerType,
-{
-    let n: usize = a.n();
-    let size: usize = a.size();
-    let cols: usize = a.cols();
-
-    #[cfg(debug_assertions)]
-    {
-        assert!(
-            tmp_bytes.len() >= rsh_tmp_bytes::(n),
-            "invalid carry: carry.len()/size_ofSelf::Scalar={} < rsh_tmp_bytes({}, {})",
-            tmp_bytes.len() / size_of::(),
-            n,
-            size,
-        );
-        assert_alignement(tmp_bytes.as_ptr());
-    }
-
-    let size: usize = a.size();
-    let steps: usize = k / log_base2k;
-
-    a.raw_mut().rotate_right(n * steps * cols);
-    (0..cols).for_each(|i| {
-        (0..steps).for_each(|j| {
-            a.zero_at(i, j);
-        })
-    });
-
-    let k_rem: usize = k % log_base2k;
-
-    if k_rem != 0 {
-        let carry: &mut [V::Scalar] = cast_mut(tmp_bytes);
-
-        unsafe {
-            std::ptr::write_bytes(carry.as_mut_ptr(), 0, n * size_of::());
-        }
-
-        let log_base2k_t: V::Scalar = V::Scalar::try_from(log_base2k).unwrap();
-        let shift: V::Scalar = V::Scalar::try_from(V::Scalar::BITS as usize - k_rem).unwrap();
-        let k_rem_t: V::Scalar = V::Scalar::try_from(k_rem).unwrap();
-
-        (steps..size).for_each(|i| {
-            izip!(carry.iter_mut(), a.at_mut(a_col, i).iter_mut()).for_each(|(ci, xi)| {
-                *xi += *ci << log_base2k_t;
-                *ci = get_base_k_carry(*xi, shift);
-                *xi = (*xi - *ci) >> k_rem_t;
-            });
-        })
-    }
-}
-
-#[inline(always)]
-fn get_base_k_carry(x: T, shift: T) -> T {
-    (x << shift) >> shift
-}
-
-pub fn rsh_tmp_bytes(n: usize) -> usize {
-    n * std::mem::size_of::()
-}
-
-pub fn switch_degree(b: &mut DMut, col_b: usize, a: &D, col_a: usize) {
+pub fn switch_degree + ZnxZero, D: ZnxView>(
+    b: &mut DMut,
+    col_b: usize,
+    a: &D,
+    col_a: usize,
+) {
     let (n_in, n_out) = (a.n(), b.n());

     let (gap_in, gap_out): (usize, usize);
@@ -325,6 +261,71 @@ pub fn switch_degree(b: &mut DMut, col_b
     });
 }

+// (Jay)TODO: implement rsh for VecZnx, VecZnxBig
+// pub trait ZnxRsh: ZnxZero {
+// fn rsh(&mut self, k: usize, log_base2k: usize, col: usize, carry: &mut [u8]) {
+// rsh(k, log_base2k, self, col, carry)
+// }
+// }
+// pub fn rsh(k: usize, log_base2k: usize, a: &mut V, a_col: usize, tmp_bytes: &mut [u8]) {
+// let n: usize = a.n();
+// let size: usize = a.size();
+// let cols: usize = a.cols();
+
+// #[cfg(debug_assertions)]
+// {
+// assert!(
+// tmp_bytes.len() >= rsh_tmp_bytes::(n),
+// "invalid carry: carry.len()/size_ofSelf::Scalar={} < rsh_tmp_bytes({}, {})",
+// tmp_bytes.len() / size_of::(),
+// n,
+// size,
+// );
+// assert_alignement(tmp_bytes.as_ptr());
+// }
+
+// let size: usize = a.size();
+// let steps: usize = k / log_base2k;
+
+// a.raw_mut().rotate_right(n * steps * cols);
+// (0..cols).for_each(|i| {
+// (0..steps).for_each(|j| {
+// a.zero_at(i, j);
+// })
+// });
+
+// let k_rem: usize = k % log_base2k;
+
+// if k_rem != 0 {
+// let carry: &mut [V::Scalar] = cast_mut(tmp_bytes);
+
+// unsafe {
+// std::ptr::write_bytes(carry.as_mut_ptr(), 0, n * size_of::());
+// }
+
+// let log_base2k_t: V::Scalar = V::Scalar::try_from(log_base2k).unwrap();
+// let shift: V::Scalar = V::Scalar::try_from(V::Scalar::BITS as usize - k_rem).unwrap();
+// let k_rem_t: V::Scalar = V::Scalar::try_from(k_rem).unwrap();
+
+// (steps..size).for_each(|i| {
+// izip!(carry.iter_mut(), a.at_mut(a_col, i).iter_mut()).for_each(|(ci, xi)| {
+// *xi += *ci << log_base2k_t;
+// *ci = get_base_k_carry(*xi, shift);
+// *xi = (*xi - *ci) >> k_rem_t;
+// });
+// })
+// }
+// }
+
+// #[inline(always)]
+// fn get_base_k_carry(x: T, shift: T) -> T {
+// (x << shift) >> shift
+// }
+
+// pub fn rsh_tmp_bytes(n: usize) -> usize {
+// n * std::mem::size_of::()
+// }
+
 // pub trait ZnxLayout: ZnxInfos {
 // type Scalar;
diff --git a/rlwe/src/automorphism.rs b/rlwe/src/automorphism.rs
index 95a935f..ea2b834 100644
--- a/rlwe/src/automorphism.rs
+++ b/rlwe/src/automorphism.rs
@@ -7,8 +7,8 @@ use crate::{
     parameters::Parameters,
 };
 use base2k::{
-    Module, Scalar, ScalarOps, ScalarZnxDft, ScalarZnxDftOps, VecZnx, VecZnxBig, VecZnxBigOps, VecZnxDft, VecZnxDftOps, VecZnxOps, MatZnxDft,
-    MatZnxDftOps, assert_alignement,
+    MatZnxDft, MatZnxDftOps, Module, Scalar, ScalarAlloc, ScalarZnxDft, ScalarZnxDftOps, VecZnx, VecZnxBig, VecZnxBigOps,
+    VecZnxDft, VecZnxDftOps, VecZnxOps, assert_alignement,
 };
 use sampling::source::Source;
 use std::collections::HashMap;
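Reviewer note: the `_ab`/`_ba` suffixes introduced above are easy to mix up. The sketch below models the documented semantics on plain slices; it deliberately avoids the base2k API, whose generic parameters are still in flux in this PR, and `poly_sub_ab_inplace`/`poly_sub_ba_inplace` are hypothetical helper names used only for illustration.

```rust
/// Models `vec_znx_sub_ab_inplace` on a single coefficient vector: res -= a.
fn poly_sub_ab_inplace(res: &mut [i64], a: &[i64]) {
    res.iter_mut().zip(a.iter()).for_each(|(r, &x)| *r -= x);
}

/// Models `vec_znx_sub_ba_inplace` on a single coefficient vector: res = a - res.
fn poly_sub_ba_inplace(res: &mut [i64], a: &[i64]) {
    res.iter_mut().zip(a.iter()).for_each(|(r, &x)| *r = x - *r);
}

fn main() {
    let a = [5i64, 7, 9];
    let mut ab = vec![1i64, 2, 3];
    let mut ba = ab.clone();

    poly_sub_ab_inplace(&mut ab, &a); // ab = ab - a
    poly_sub_ba_inplace(&mut ba, &a); // ba = a - ba

    assert_eq!(ab, vec![-4, -5, -6]);
    assert_eq!(ba, vec![4, 5, 6]);
}
```

Note that the new in-place FFI dispatches pass the `res` column both as the output and as one of the two inputs, so they assume `vec_znx::vec_znx_sub` tolerates aliased operand pointers; the reference model above sidesteps that question entirely.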