From ffa363804b062cc0cceffd4eadbb18950a8b75bd Mon Sep 17 00:00:00 2001 From: Jean-Philippe Bossuat Date: Mon, 5 May 2025 17:35:35 +0200 Subject: [PATCH] rework as discussed --- base2k/examples/rlwe_encrypt.rs | 43 +- base2k/examples/vmp.rs | 78 ---- base2k/src/lib.rs | 8 + base2k/src/mat_znx_dft.rs | 104 +++-- base2k/src/mat_znx_dft_ops.rs | 649 +++++++++---------------------- base2k/src/sampling.rs | 63 ++- base2k/src/scalar_znx.rs | 100 +++-- base2k/src/scalar_znx_dft.rs | 92 +++-- base2k/src/scalar_znx_dft_ops.rs | 70 ++-- base2k/src/vec_znx.rs | 77 +++- base2k/src/vec_znx_big.rs | 71 +++- base2k/src/vec_znx_big_ops.rs | 363 +++++++++-------- base2k/src/vec_znx_dft.rs | 65 +++- base2k/src/vec_znx_dft_ops.rs | 123 +++--- base2k/src/vec_znx_ops.rs | 371 ++++++++++++------ base2k/src/znx_base.rs | 30 +- 16 files changed, 1154 insertions(+), 1153 deletions(-) delete mode 100644 base2k/examples/vmp.rs diff --git a/base2k/examples/rlwe_encrypt.rs b/base2k/examples/rlwe_encrypt.rs index 742dcea..b55efba 100644 --- a/base2k/examples/rlwe_encrypt.rs +++ b/base2k/examples/rlwe_encrypt.rs @@ -1,6 +1,7 @@ use base2k::{ - Encoding, FFT64, Module, Sampling, ScalarAlloc, ScalarZnxDftAlloc, ScalarZnxDftOps, ScratchOwned, VecZnxAlloc, VecZnxBigOps, - VecZnxBigScratch, VecZnxDftAlloc, VecZnxDftOps, ZnxInfos, + Encoding, FFT64, Module, Sampling, Scalar, ScalarAlloc, ScalarZnxDft, ScalarZnxDftAlloc, ScalarZnxDftOps, ScratchOwned, + VecZnx, VecZnxAlloc, VecZnxBig, VecZnxBigAlloc, VecZnxBigOps, VecZnxBigScratch, VecZnxDft, VecZnxDftAlloc, VecZnxDftOps, + VecZnxOps, ZnxInfos, }; use itertools::izip; use sampling::source::Source; @@ -13,24 +14,23 @@ fn main() { let log_scale: usize = msg_size * log_base2k - 5; let module: Module = Module::::new(n); - let mut scratch = - ScratchOwned::new((2 * module.bytes_of_vec_znx_dft(1, ct_size)) + 2 * module.vec_znx_big_normalize_tmp_bytes()); + let mut scratch: ScratchOwned = ScratchOwned::new(module.vec_znx_big_normalize_tmp_bytes()); let seed: [u8; 32] = [0; 32]; let mut source: Source = Source::new(seed); // s <- Z_{-1, 0, 1}[X]/(X^{N}+1) - let mut s = module.new_scalar(1); + let mut s: Scalar> = module.new_scalar(1); s.fill_ternary_prob(0, 0.5, &mut source); // Buffer to store s in the DFT domain - let mut s_dft = module.new_scalar_znx_dft(s.cols()); + let mut s_dft: ScalarZnxDft, FFT64> = module.new_scalar_znx_dft(s.cols()); // s_dft <- DFT(s) module.svp_prepare(&mut s_dft, 0, &s, 0); // Allocates a VecZnx with two columns: ct=(0, 0) - let mut ct = module.new_vec_znx( + let mut ct: VecZnx> = module.new_vec_znx( 2, // Number of columns ct_size, // Number of small poly per column ); @@ -38,12 +38,10 @@ fn main() { // Fill the second column with random values: ct = (0, a) module.fill_uniform(log_base2k, &mut ct, 1, ct_size, &mut source); - // Scratch space for DFT values - let scratch = scratch.borrow(); - let (mut buf_dft, scratch) = scratch.tmp_vec_znx_dft(&module, 1, ct_size); + let mut buf_dft: VecZnxDft, FFT64> = module.new_vec_znx_dft(1, ct_size); // Applies DFT(ct[1]) * DFT(s) - module.svp_apply_dft( + module.svp_apply( &mut buf_dft, // DFT(ct[1] * s) 0, // Selects the first column of res &s_dft, // DFT(s) @@ -53,11 +51,10 @@ fn main() { ); // Alias scratch space (VecZnxDft is always at least as big as VecZnxBig) - let (mut buf_big, scratch) = scratch.tmp_vec_znx_big(&module, 1, ct_size); // BIG(ct[1] * s) <- IDFT(DFT(ct[1] * s)) (not normalized) - // Note: Since `vec_znx_idft_tmp_a` takes no argument for generic `Data` a full qualified path seems necessary - as 
VecZnxDftOps<_, &[u8], _>>::vec_znx_idft_tmp_a(&module, &mut buf_big, 0, &mut buf_dft, 0); + let mut buf_big: VecZnxBig, FFT64> = module.new_vec_znx_big(1, ct_size); + module.vec_znx_idft_tmp_a(&mut buf_big, 0, &mut buf_dft, 0); // Creates a plaintext: VecZnx with 1 column let mut m = module.new_vec_znx( @@ -68,8 +65,7 @@ fn main() { want.iter_mut() .for_each(|x| *x = source.next_u64n(16, 15) as i64); m.encode_vec_i64(0, log_base2k, log_scale, &want, 4); - let (tmp_bytes_norm, scratch) = scratch.tmp_scalar_slice(n * std::mem::size_of::()); - m.normalize(log_base2k, 0, tmp_bytes_norm); + module.vec_znx_normalize_inplace(log_base2k, &mut m, 0, scratch.borrow()); // m - BIG(ct[1] * s) module.vec_znx_big_sub_small_b_inplace( @@ -82,9 +78,12 @@ fn main() { // Normalizes back to VecZnx // ct[0] <- m - BIG(c1 * s) module.vec_znx_big_normalize( - log_base2k, &mut ct, 0, // Selects the first column of ct (ct[0]) - &buf_big, 0, // Selects the first column of buf_big - scratch, + log_base2k, + &mut ct, + 0, // Selects the first column of ct (ct[0]) + &buf_big, + 0, // Selects the first column of buf_big + scratch.borrow(), ); // Add noise to ct[0] @@ -104,7 +103,7 @@ fn main() { // Decryption // DFT(ct[1] * s) - module.svp_apply_dft( + module.svp_apply( &mut buf_dft, 0, // Selects the first column of res. &s_dft, @@ -114,14 +113,14 @@ fn main() { ); // BIG(c1 * s) = IDFT(DFT(c1 * s)) - as VecZnxDftOps<_, &[u8], _>>::vec_znx_idft_tmp_a(&module, &mut buf_big, 0, &mut buf_dft, 0); + module.vec_znx_idft_tmp_a(&mut buf_big, 0, &mut buf_dft, 0); // BIG(c1 * s) + ct[0] module.vec_znx_big_add_small_inplace(&mut buf_big, 0, &ct, 0); // m + e <- BIG(ct[1] * s + ct[0]) let mut res = module.new_vec_znx(1, ct_size); - module.vec_znx_big_normalize(log_base2k, &mut res, 0, &buf_big, 0, scratch); + module.vec_znx_big_normalize(log_base2k, &mut res, 0, &buf_big, 0, scratch.borrow()); // have = m * 2^{log_scale} + e let mut have: Vec = vec![i64::default(); n]; diff --git a/base2k/examples/vmp.rs b/base2k/examples/vmp.rs deleted file mode 100644 index 36943f7..0000000 --- a/base2k/examples/vmp.rs +++ /dev/null @@ -1,78 +0,0 @@ -// use base2k::{ -// Encoding, FFT64, MatZnxDft, MatZnxDftOps, Module, VecZnx, VecZnxBig, VecZnxBigOps, VecZnxDft, VecZnxDftOps, VecZnxOps, -// ZnxInfos, ZnxLayout, alloc_aligned, -// }; - -fn main() { - // let log_n: i32 = 5; - // let n: usize = 1 << log_n; - - // let module: Module = Module::::new(n); - // let log_base2k: usize = 15; - - // let a_cols: usize = 2; - // let a_size: usize = 5; - - // let log_k: usize = log_base2k * a_size - 5; - - // let mat_rows: usize = a_size; - // let mat_cols_in: usize = a_cols; - // let mat_cols_out: usize = 2; - // let mat_size: usize = a_size + 1; - - // let mut tmp_bytes_vmp: Vec = alloc_aligned( - // module.vmp_prepare_row_tmp_bytes(mat_cols_out, mat_size) - // | module.vmp_apply_dft_tmp_bytes( - // a_size, - // a_size, - // mat_rows, - // mat_cols_in, - // mat_cols_out, - // mat_size, - // ), - // ); - - // let mut tmp_bytes_dft: Vec = alloc_aligned(module.bytes_of_vec_znx_dft(mat_cols_out, mat_size)); - - // let mut a: VecZnx = module.new_vec_znx(a_cols, a_size); - - // (0..a_cols).for_each(|i| { - // let mut values: Vec = vec![i64::default(); n]; - // values[1 + i] = (1 << log_base2k) + 1; - // a.encode_vec_i64(i, log_base2k, log_k, &values, 32); - // a.normalize(log_base2k, i, &mut tmp_bytes_vmp); - // a.print(n, i); - // println!(); - // }); - - // let mut mat_znx_dft: MatZnxDft = module.new_mat_znx_dft(mat_rows, mat_cols_in, mat_cols_out, 
mat_size); - - // (0..a.size()).for_each(|row_i| { - // let mut tmp: VecZnx = module.new_vec_znx(mat_cols_out, mat_size); - // (0..mat_cols_out).for_each(|j| { - // tmp.at_mut(j, row_i)[1 + j] = 1 as i64; - // }); - // (0..mat_cols_in).for_each(|j| { - // module.vmp_prepare_row(&mut mat_znx_dft, row_i, j, &tmp, &mut tmp_bytes_vmp); - // }) - // }); - - // let mut c_dft: VecZnxDft = module.new_vec_znx_dft_from_bytes_borrow(mat_cols_out, mat_size, &mut tmp_bytes_dft); - // module.vmp_apply_dft(&mut c_dft, &a, &mat_znx_dft, &mut tmp_bytes_vmp); - - // let mut res: VecZnx = module.new_vec_znx(mat_cols_out, a_size); - // let mut c_big: VecZnxBig = c_dft.alias_as_vec_znx_big(); - // (0..mat_cols_out).for_each(|i| { - // module.vec_znx_idft_tmp_a(&mut c_big, i, &mut c_dft, i); - // module.vec_znx_big_normalize(log_base2k, &mut res, i, &c_big, i, &mut tmp_bytes_vmp); - - // let mut values_res: Vec = vec![i64::default(); n]; - // res.decode_vec_i64(i, log_base2k, log_k, &mut values_res); - // res.print(n, i); - // println!(); - // println!("{:?}", values_res); - // println!(); - // }); - - // module.free(); -} diff --git a/base2k/src/lib.rs b/base2k/src/lib.rs index 38d6b4e..f3b2525 100644 --- a/base2k/src/lib.rs +++ b/base2k/src/lib.rs @@ -215,4 +215,12 @@ impl Scratch { Self::new(rem_slice), ) } + + pub fn tmp_vec_znx(&mut self, module: &Module, cols: usize, size: usize) -> (VecZnx<&mut [u8]>, &mut Self) { + let (take_slice, rem_slice) = Self::take_slice_aligned(&mut self.data, module.bytes_of_vec_znx(cols, size)); + ( + VecZnx::from_data(take_slice, module.n(), cols, size), + Self::new(rem_slice), + ) + } } diff --git a/base2k/src/mat_znx_dft.rs b/base2k/src/mat_znx_dft.rs index 7a39dd1..1f18b48 100644 --- a/base2k/src/mat_znx_dft.rs +++ b/base2k/src/mat_znx_dft.rs @@ -1,5 +1,5 @@ use crate::znx_base::ZnxInfos; -use crate::{Backend, DataView, DataViewMut, FFT64, Module, ZnxView, alloc_aligned}; +use crate::{Backend, DataView, DataViewMut, FFT64, Module, ZnxSliceSize, ZnxView, alloc_aligned}; use std::marker::PhantomData; /// Vector Matrix Product Prepared Matrix: a vector of [VecZnx], @@ -8,17 +8,17 @@ use std::marker::PhantomData; /// /// [MatZnxDft] is used to permform a vector matrix product between a [VecZnx]/[VecZnxDft] and a [MatZnxDft]. /// See the trait [MatZnxDftOps] for additional information. 
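// A minimal sketch of how a freshly allocated MatZnxDft is filled row by row with the
// reworked API, assuming the `new_mat_znx_dft` / `vmp_prepare_row` signatures introduced
// later in this patch; `rows`, `cols_in`, `cols_out` and `size` are illustrative values,
// and `a_dft` is assumed to already hold `cols_out` columns of `size` limbs in the DFT domain.
fn prepare_rows_sketch(module: &Module<FFT64>, a_dft: &VecZnxDft<Vec<u8>, FFT64>) {
    let (rows, cols_in, cols_out, size) = (4usize, 2usize, 2usize, 4usize);
    let mut mat: MatZnxDft<Vec<u8>, FFT64> = module.new_mat_znx_dft(rows, cols_in, cols_out, size);
    for row in 0..rows {
        for col_in in 0..cols_in {
            // Each (row, col_in) slot of the prepared matrix receives one DFT vector.
            module.vmp_prepare_row(&mut mat, row, col_in, a_dft);
        }
    }
}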
-pub struct MatZnxDft { +pub struct MatZnxDft { data: D, n: usize, size: usize, rows: usize, cols_in: usize, cols_out: usize, - _marker: PhantomData, + _phantom: PhantomData, } -impl ZnxInfos for MatZnxDft { +impl ZnxInfos for MatZnxDft { fn cols(&self) -> usize { self.cols_in } @@ -34,20 +34,22 @@ impl ZnxInfos for MatZnxDft { fn size(&self) -> usize { self.size } +} +impl ZnxSliceSize for MatZnxDft { fn sl(&self) -> usize { - self.n() + self.n() * self.cols_out() } } -impl DataView for MatZnxDft { +impl DataView for MatZnxDft { type D = D; fn data(&self) -> &Self::D { &self.data } } -impl DataViewMut for MatZnxDft { +impl DataViewMut for MatZnxDft { fn data_mut(&mut self) -> &mut Self::D { &mut self.data } @@ -57,7 +59,7 @@ impl> ZnxView for MatZnxDft { type Scalar = f64; } -impl MatZnxDft { +impl MatZnxDft { pub(crate) fn cols_in(&self) -> usize { self.cols_in } @@ -87,7 +89,7 @@ impl>, B: Backend> MatZnxDft { rows, cols_in, cols_out, - _marker: PhantomData, + _phantom: PhantomData, } } @@ -108,7 +110,7 @@ impl>, B: Backend> MatZnxDft { rows, cols_in, cols_out, - _marker: PhantomData, + _phantom: PhantomData, } } } @@ -151,28 +153,80 @@ impl> MatZnxDft { pub type MatZnxDftAllocOwned = MatZnxDft, B>; -impl MatZnxDft, B> { - pub fn to_mut(&mut self) -> MatZnxDft<&mut [u8], B> { +pub trait MatZnxDftToRef { + fn to_ref(&self) -> MatZnxDft<&[u8], B>; +} + +pub trait MatZnxDftToMut { + fn to_mut(&mut self) -> MatZnxDft<&mut [u8], B>; +} + +impl MatZnxDftToMut for MatZnxDft, B> { + fn to_mut(&mut self) -> MatZnxDft<&mut [u8], B> { MatZnxDft { data: self.data.as_mut_slice(), n: self.n, - size: self.size, rows: self.rows, cols_in: self.cols_in, cols_out: self.cols_out, - _marker: PhantomData, - } - } - - pub fn to_ref(&self) -> MatZnxDft<&[u8], B> { - MatZnxDft { - data: self.data.as_slice(), - n: self.n, size: self.size, - rows: self.rows, - cols_in: self.cols_in, - cols_out: self.cols_out, - _marker: PhantomData, + _phantom: PhantomData, + } + } +} + +impl MatZnxDftToRef for MatZnxDft, B> { + fn to_ref(&self) -> MatZnxDft<&[u8], B> { + MatZnxDft { + data: self.data.as_slice(), + n: self.n, + rows: self.rows, + cols_in: self.cols_in, + cols_out: self.cols_out, + size: self.size, + _phantom: PhantomData, + } + } +} + +impl MatZnxDftToMut for MatZnxDft<&mut [u8], B> { + fn to_mut(&mut self) -> MatZnxDft<&mut [u8], B> { + MatZnxDft { + data: self.data, + n: self.n, + rows: self.rows, + cols_in: self.cols_in, + cols_out: self.cols_out, + size: self.size, + _phantom: PhantomData, + } + } +} + +impl MatZnxDftToRef for MatZnxDft<&mut [u8], B> { + fn to_ref(&self) -> MatZnxDft<&[u8], B> { + MatZnxDft { + data: self.data, + n: self.n, + rows: self.rows, + cols_in: self.cols_in, + cols_out: self.cols_out, + size: self.size, + _phantom: PhantomData, + } + } +} + +impl MatZnxDftToRef for MatZnxDft<&[u8], B> { + fn to_ref(&self) -> MatZnxDft<&[u8], B> { + MatZnxDft { + data: self.data, + n: self.n, + rows: self.rows, + cols_in: self.cols_in, + cols_out: self.cols_out, + size: self.size, + _phantom: PhantomData, } } } diff --git a/base2k/src/mat_znx_dft_ops.rs b/base2k/src/mat_znx_dft_ops.rs index 658ff5d..9b79a2c 100644 --- a/base2k/src/mat_znx_dft_ops.rs +++ b/base2k/src/mat_znx_dft_ops.rs @@ -2,11 +2,11 @@ use crate::ffi::vec_znx_dft::vec_znx_dft_t; use crate::ffi::vmp; use crate::znx_base::{ZnxInfos, ZnxView, ZnxViewMut}; use crate::{ - Backend, FFT64, MatZnxDft, MatZnxDftAllocOwned, Module, Scratch, VecZnx, VecZnxBigOps, VecZnxBigScratch, VecZnxDft, - VecZnxDftAlloc, VecZnxDftOps, + Backend, FFT64, 
MatZnxDft, MatZnxDftAllocOwned, MatZnxDftToMut, MatZnxDftToRef, Module, Scratch, VecZnxDft, VecZnxDftToMut, + VecZnxDftToRef, }; -pub trait MatZnxDftAlloc { +pub trait MatZnxDftAlloc { /// Allocates a new [MatZnxDft] with the given number of rows and columns. /// /// # Arguments @@ -28,43 +28,10 @@ pub trait MatZnxDftAlloc { } pub trait MatZnxDftScratch { - /// Returns the of bytes needed as scratch space for [MatZnxDftOps::vmp_prepare_row] - fn vmp_prepare_row_tmp_bytes(&self, cols_out: usize, size: usize) -> usize; - - /// Returns the of bytes needed as scratch space for [MatZnxDftOps::vmp_extract_row] - fn vmp_extract_row_tmp_bytes(&self, cols_out: usize, size: usize) -> usize; - - /// Returns the size of the stratch space necessary for [MatZnxDftOps::vmp_apply_dft]. - /// - /// # Arguments - /// - /// * `c_size`: number of size of the output [VecZnxDft]. - /// * `a_size`: number of size of the input [VecZnx]. - /// * `rows`: number of rows of the input [MatZnxDft]. - /// * `size`: number of size of the input [MatZnxDft]. - fn vmp_apply_dft_tmp_bytes( - &self, - c_size: usize, - a_size: usize, - b_rows: usize, - b_cols_in: usize, - b_cols_out: usize, - b_size: usize, - ) -> usize; - /// Returns the size of the stratch space necessary for [MatZnxDftOps::vmp_apply_dft_to_dft]. - /// - /// # Arguments - /// - /// * `c_size`: number of size of the output [VecZnxDft]. - /// * `a_size`: number of size of the input [VecZnxDft]. - /// * `rows`: number of rows of the input [MatZnxDft]. - /// * `size`: number of size of the input [MatZnxDft]. - fn vmp_apply_dft_to_dft_tmp_bytes( + fn vmp_apply_tmp_bytes( &self, - c_cols: usize, - c_size: usize, - a_cols: usize, + res_size: usize, a_size: usize, b_rows: usize, b_cols_in: usize, @@ -75,43 +42,7 @@ pub trait MatZnxDftScratch { /// This trait implements methods for vector matrix product, /// that is, multiplying a [VecZnx] with a [MatZnxDft]. -pub trait MatZnxDftOps { - /// Prepares the ith-row of [MatZnxDft] from a [VecZnx]. - /// - /// # Arguments - /// - /// * `b`: [MatZnxDft] on which the values are encoded. - /// * `row_i`: the row of the [MatZnxDft] to prepare. - /// * `a`: the [VecZnx] to encode on the i-th row of the [MatZnxDft]. - /// * `buf`: scratch space, the size of buf can be obtained with [MatZnxDftOps::vmp_prepare_tmp_bytes]. - /// - /// The size of buf can be obtained with [MatZnxDftOps::vmp_prepare_tmp_bytes]. - fn vmp_prepare_row( - &self, - b: &mut MatZnxDft, - b_row: usize, - b_col_in: usize, - a: &VecZnx, - scratch: &mut Scratch, - ); - - /// Extracts the ith-row of [MatZnxDft] into a [VecZnxBig]. - /// - /// # Arguments - /// - /// * `b`: the [VecZnxBig] to on which to extract the row of the [MatZnxDft]. - /// * `a`: [MatZnxDft] on which the values are encoded. - /// * `row_i`: the index of the row to extract. - fn vmp_extract_row( - &self, - log_base2k: usize, - b: &mut VecZnx, - a: &MatZnxDft, - b_row: usize, - b_col_in: usize, - scratch: &mut Scratch, - ); - +pub trait MatZnxDftOps { /// Prepares the ith-row of [MatZnxDft] from a [VecZnxDft]. /// /// # Arguments @@ -121,7 +52,10 @@ pub trait MatZnxDftOps { /// * `row_i`: the index of the row to prepare. /// /// The size of buf can be obtained with [MatZnxDftOps::vmp_prepare_tmp_bytes]. - fn vmp_prepare_row_dft(&self, b: &mut MatZnxDft, b_row: usize, b_col_in: usize, a: &VecZnxDft); + fn vmp_prepare_row(&self, res: &mut R, res_row: usize, res_col_in: usize, a: &A) + where + R: MatZnxDftToMut, + A: VecZnxDftToRef; /// Extracts the ith-row of [MatZnxDft] into a [VecZnxDft]. 
/// @@ -130,33 +64,10 @@ pub trait MatZnxDftOps { /// * `b`: the [VecZnxDft] to on which to extract the row of the [MatZnxDft]. /// * `a`: [MatZnxDft] on which the values are encoded. /// * `row_i`: the index of the row to extract. - fn vmp_extract_row_dft(&self, b: &mut VecZnxDft, a: &MatZnxDft, a_row: usize, a_col_in: usize); - - /// Applies the vector matrix product [VecZnxDft] x [MatZnxDft]. - /// - /// A vector matrix product is equivalent to a sum of [crate::SvpPPolOps::svp_apply_dft] - /// where each [crate::Scalar] is a limb of the input [VecZnxDft] (equivalent to an [crate::SvpPPol]) - /// and each vector a [VecZnxDft] (row) of the [MatZnxDft]. - /// - /// As such, given an input [VecZnx] of `i` size and a [MatZnxDft] of `i` rows and - /// `j` size, the output is a [VecZnx] of `j` size. - /// - /// If there is a mismatch between the dimensions the largest valid ones are used. - /// - /// ```text - /// |a b c d| x |e f g| = (a * |e f g| + b * |h i j| + c * |k l m|) = |n o p| - /// |h i j| - /// |k l m| - /// ``` - /// where each element is a [VecZnxDft]. - /// - /// # Arguments - /// - /// * `c`: the output of the vector matrix product, as a [VecZnxDft]. - /// * `a`: the left operand [VecZnx] of the vector matrix product. - /// * `b`: the right operand [MatZnxDft] of the vector matrix product. - /// * `buf`: scratch space, the size can be obtained with [MatZnxDftOps::vmp_apply_dft_tmp_bytes]. - fn vmp_apply_dft(&self, c: &mut VecZnxDft, a: &VecZnx, b: &MatZnxDft, scratch: &mut Scratch); + fn vmp_extract_row(&self, res: &mut R, a: &A, a_row: usize, a_col_in: usize) + where + R: VecZnxDftToMut, + A: MatZnxDftToRef; /// Applies the vector matrix product [VecZnxDft] x [MatZnxDft]. /// The size of `buf` is given by [MatZnxDftOps::vmp_apply_dft_to_dft_tmp_bytes]. @@ -183,13 +94,11 @@ pub trait MatZnxDftOps { /// * `a`: the left operand [VecZnxDft] of the vector matrix product. /// * `b`: the right operand [MatZnxDft] of the vector matrix product. /// * `buf`: scratch space, the size can be obtained with [MatZnxDftOps::vmp_apply_dft_to_dft_tmp_bytes]. 
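// A minimal caller-side sketch of the reworked vector-matrix product, assuming the
// signatures introduced in this file (it mirrors the `vmp_apply` test further down);
// the sizes are illustrative and `mat` is assumed to have been allocated and prepared
// with matching (rows, cols_in, cols_out, size).
fn vmp_apply_sketch(module: &Module<FFT64>, mat: &MatZnxDft<Vec<u8>, FFT64>) {
    let (rows, cols_in, cols_out, size) = (4usize, 2usize, 2usize, 4usize);
    let mut scratch: ScratchOwned =
        ScratchOwned::new(module.vmp_apply_tmp_bytes(size, size, rows, cols_in, cols_out, size));
    let a_dft: VecZnxDft<Vec<u8>, FFT64> = module.new_vec_znx_dft(cols_in, size);
    let mut res_dft: VecZnxDft<Vec<u8>, FFT64> = module.new_vec_znx_dft(cols_out, size);
    // res_dft <- a_dft x mat, summing the partial products over the `cols_in` input columns.
    module.vmp_apply(&mut res_dft, &a_dft, mat, scratch.borrow());
}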
- fn vmp_apply_dft_to_dft( - &self, - c: &mut VecZnxDft, - a: &VecZnxDft, - b: &MatZnxDft, - scratch: &mut Scratch, - ); + fn vmp_apply(&self, res: &mut R, a: &A, b: &B, scratch: &mut Scratch) + where + R: VecZnxDftToMut, + A: VecZnxDftToRef, + B: MatZnxDftToRef; } impl MatZnxDftAlloc for Module { @@ -213,40 +122,10 @@ impl MatZnxDftAlloc for Module { } } -impl MatZnxDftScratch for Module { - fn vmp_prepare_row_tmp_bytes(&self, cols_out: usize, size: usize) -> usize { - >::bytes_of_vec_znx_dft(self, cols_out, size) - } - - fn vmp_extract_row_tmp_bytes(&self, cols_out: usize, size: usize) -> usize { - >::bytes_of_vec_znx_dft(self, cols_out, size) - + ::vec_znx_big_normalize_tmp_bytes(self) - } - - fn vmp_apply_dft_tmp_bytes( +impl MatZnxDftScratch for Module { + fn vmp_apply_tmp_bytes( &self, - c_size: usize, - a_size: usize, - b_rows: usize, - b_cols_in: usize, - b_cols_out: usize, - b_size: usize, - ) -> usize { - unsafe { - vmp::vmp_apply_dft_tmp_bytes( - self.ptr, - c_size as u64, - a_size as u64, - (b_rows * b_cols_in) as u64, - (b_size * b_cols_out) as u64, - ) as usize - } - } - fn vmp_apply_dft_to_dft_tmp_bytes( - &self, - c_cols: usize, - c_size: usize, - a_cols: usize, + res_size: usize, a_size: usize, b_rows: usize, b_cols_in: usize, @@ -256,8 +135,8 @@ impl MatZnxDftScratch for Module { unsafe { vmp::vmp_apply_dft_to_dft_tmp_bytes( self.ptr, - (c_size * c_cols) as u64, - (a_size * a_cols) as u64, + (res_size * b_cols_out) as u64, + (a_size * b_cols_in) as u64, (b_rows * b_cols_in) as u64, (b_size * b_cols_out) as u64, ) as usize @@ -265,152 +144,43 @@ impl MatZnxDftScratch for Module { } } -impl MatZnxDftOps<&mut [u8], &[u8], FFT64> for Module { - fn vmp_prepare_row( - &self, - b: &mut MatZnxDft<&mut [u8], FFT64>, - b_row: usize, - b_col_in: usize, - a: &VecZnx<&[u8]>, - scratch: &mut Scratch, - ) { +impl MatZnxDftOps for Module { + fn vmp_prepare_row(&self, res: &mut R, res_row: usize, res_col_in: usize, a: &A) + where + R: MatZnxDftToMut, + A: VecZnxDftToRef, + { + let mut res: MatZnxDft<&mut [u8], _> = res.to_mut(); + let a: VecZnxDft<&[u8], _> = a.to_ref(); + #[cfg(debug_assertions)] { - assert_eq!(b.n(), self.n()); + assert_eq!(res.n(), self.n()); assert_eq!(a.n(), self.n()); assert_eq!( a.cols(), - b.cols_out(), - "a.cols(): {} != b.cols_out(): {}", + res.cols_out(), + "a.cols(): {} != res.cols_out(): {}", a.cols(), - b.cols_out() + res.cols_out() ); assert!( - b_row < b.rows(), - "b_row: {} >= b.rows(): {}", - b_row, - b.rows() + res_row < res.rows(), + "res_row: {} >= res.rows(): {}", + res_row, + res.rows() ); assert!( - b_col_in < b.cols_in(), - "b_col_in: {} >= b.cols_in(): {}", - b_col_in, - b.cols_in() + res_col_in < res.cols_in(), + "res_col_in: {} >= res.cols_in(): {}", + res_col_in, + res.cols_in() ); assert_eq!( - b.size(), + res.size(), a.size(), - "b.size(): {} != a.size(): {}", - b.size(), - a.size() - ); - // assert!( - // tmp_bytes.len() - // >= >::vmp_prepare_row_tmp_bytes(self, a.cols(), a.size()) - // ); - // assert!(is_aligned(tmp_bytes.as_ptr())) - } - - let cols_out: usize = a.cols(); - let a_size: usize = a.size(); - - // let (tmp_bytes_a_dft, _) = tmp_bytes.split_at_mut(self.bytes_of_vec_znx_dft(cols_out, a_size)); - let (mut a_dft, _) = scratch.tmp_vec_znx_dft::<_>(self, cols_out, a_size); - (0..cols_out).for_each(|i| self.vec_znx_dft(&mut a_dft, i, &a, i)); - Self::vmp_prepare_row_dft(&self, b, b_row, b_col_in, &a_dft.to_ref()); - } - - fn vmp_extract_row( - &self, - log_base2k: usize, - b: &mut VecZnx<&mut [u8]>, - a: &MatZnxDft<&[u8], FFT64>, - 
a_row: usize, - a_col_in: usize, - scratch: &mut Scratch, - ) { - #[cfg(debug_assertions)] - { - assert_eq!(b.n(), self.n()); - assert_eq!(a.n(), self.n()); - assert_eq!( - b.cols(), - a.cols_out(), - "b.cols(): {} != a.cols_out(): {}", - b.cols(), - a.cols_out() - ); - assert!( - a_row < a.rows(), - "a_row: {} >= a.rows(): {}", - a_row, - a.rows() - ); - assert!( - a_col_in < a.cols_in(), - "a_col_in: {} >= a.cols_in(): {}", - a_col_in, - a.cols_in() - ); - assert_eq!( - b.size(), - a.size(), - "b.size(): {} != a.size(): {}", - b.size(), - a.size() - ); - // assert!(tmp_bytes.len() >= self.vmp_extract_row_tmp_bytes(a.cols(), a.size())); - // assert!(is_aligned(tmp_bytes.as_ptr())) - } - - let cols_out: usize = b.cols(); - let size: usize = b.size(); - - // let (bytes_a_dft, tmp_bytes) = tmp_bytes.split_at_mut(self.bytes_of_vec_znx_dft(cols_out, size)); - let (mut b_dft, scratch) = scratch.tmp_vec_znx_dft(self, cols_out, size); - Self::vmp_extract_row_dft(&self, &mut b_dft, a, a_row, a_col_in); - let (mut b_big, scratch) = scratch.tmp_vec_znx_big(self, cols_out, size); - (0..cols_out).for_each(|i| { - >::vec_znx_idft_tmp_a(self, &mut b_big, i, &mut b_dft, i); - self.vec_znx_big_normalize(log_base2k, b, i, &b_big, i, scratch); - }); - } - - fn vmp_prepare_row_dft( - &self, - b: &mut MatZnxDft<&mut [u8], FFT64>, - b_row: usize, - b_col_in: usize, - a: &VecZnxDft<&[u8], FFT64>, - ) { - #[cfg(debug_assertions)] - { - assert_eq!(b.n(), self.n()); - assert_eq!(a.n(), self.n()); - assert_eq!( - a.cols(), - b.cols_out(), - "a.cols(): {} != b.cols_out(): {}", - a.cols(), - b.cols_out() - ); - assert!( - b_row < b.rows(), - "b_row: {} >= b.rows(): {}", - b_row, - b.rows() - ); - assert!( - b_col_in < b.cols_in(), - "b_col_in: {} >= b.cols_in(): {}", - b_col_in, - b.cols_in() - ); - assert_eq!( - b.size(), - a.size(), - "b.size(): {} != a.size(): {}", - b.size(), + "res.size(): {} != a.size(): {}", + res.size(), a.size() ); } @@ -418,31 +188,32 @@ impl MatZnxDftOps<&mut [u8], &[u8], FFT64> for Module { unsafe { vmp::vmp_prepare_row_dft( self.ptr, - b.as_mut_ptr() as *mut vmp::vmp_pmat_t, + res.as_mut_ptr() as *mut vmp::vmp_pmat_t, a.as_ptr() as *const vec_znx_dft_t, - (b_row * b.cols_in() + b_col_in) as u64, - (b.rows() * b.cols_in()) as u64, - (b.size() * b.cols_out()) as u64, + (res_row * res.cols_in() + res_col_in) as u64, + (res.rows() * res.cols_in()) as u64, + (res.size() * res.cols_out()) as u64, ); } } - fn vmp_extract_row_dft( - &self, - b: &mut VecZnxDft<&mut [u8], FFT64>, - a: &MatZnxDft<&[u8], FFT64>, - a_row: usize, - a_col_in: usize, - ) { + fn vmp_extract_row(&self, res: &mut R, a: &A, a_row: usize, a_col_in: usize) + where + R: VecZnxDftToMut, + A: MatZnxDftToRef, + { + let mut res: VecZnxDft<&mut [u8], _> = res.to_mut(); + let a: MatZnxDft<&[u8], _> = a.to_ref(); + #[cfg(debug_assertions)] { - assert_eq!(b.n(), self.n()); + assert_eq!(res.n(), self.n()); assert_eq!(a.n(), self.n()); assert_eq!( - b.cols(), + res.cols(), a.cols_out(), - "b.cols(): {} != a.cols_out(): {}", - b.cols(), + "res.cols(): {} != a.cols_out(): {}", + res.cols(), a.cols_out() ); assert!( @@ -458,17 +229,17 @@ impl MatZnxDftOps<&mut [u8], &[u8], FFT64> for Module { a.cols_in() ); assert_eq!( - b.size(), + res.size(), a.size(), - "b.size(): {} != a.size(): {}", - b.size(), + "res.size(): {} != a.size(): {}", + res.size(), a.size() ); } unsafe { vmp::vmp_extract_row_dft( self.ptr, - b.as_mut_ptr() as *mut vec_znx_dft_t, + res.as_mut_ptr() as *mut vec_znx_dft_t, a.as_ptr() as *const vmp::vmp_pmat_t, (a_row * 
a.cols_in() + a_col_in) as u64, (a.rows() * a.cols_in()) as u64, @@ -477,23 +248,26 @@ impl MatZnxDftOps<&mut [u8], &[u8], FFT64> for Module { } } - fn vmp_apply_dft( - &self, - c: &mut VecZnxDft<&mut [u8], FFT64>, - a: &VecZnx<&[u8]>, - b: &MatZnxDft<&[u8], FFT64>, - scratch: &mut Scratch, - ) { + fn vmp_apply(&self, res: &mut R, a: &A, b: &B, scratch: &mut Scratch) + where + R: VecZnxDftToMut, + A: VecZnxDftToRef, + B: MatZnxDftToRef, + { + let mut res: VecZnxDft<&mut [u8], _> = res.to_mut(); + let a: VecZnxDft<&[u8], _> = a.to_ref(); + let b: MatZnxDft<&[u8], _> = b.to_ref(); + #[cfg(debug_assertions)] { - assert_eq!(c.n(), self.n()); + assert_eq!(res.n(), self.n()); assert_eq!(b.n(), self.n()); assert_eq!(a.n(), self.n()); assert_eq!( - c.cols(), + res.cols(), b.cols_out(), - "c.cols(): {} != b.cols_out: {}", - c.cols(), + "res.cols(): {} != b.cols_out: {}", + res.cols(), b.cols_out() ); assert_eq!( @@ -503,37 +277,23 @@ impl MatZnxDftOps<&mut [u8], &[u8], FFT64> for Module { a.cols(), b.cols_in() ); - // assert!( - // tmp_bytes.len() - // >= self.vmp_apply_dft_tmp_bytes( - // c.size(), - // a.size(), - // b.rows(), - // b.cols_in(), - // b.cols_out(), - // b.size() - // ) - // ); - // assert_alignement(tmp_bytes.as_ptr()); } - let (tmp_bytes, _) = scratch.tmp_scalar_slice(::vmp_apply_dft_tmp_bytes( - self, - c.size(), + + let (tmp_bytes, _) = scratch.tmp_scalar_slice(self.vmp_apply_tmp_bytes( + res.size(), a.size(), b.rows(), b.cols_in(), b.cols_out(), b.size(), )); - unsafe { - vmp::vmp_apply_dft( + vmp::vmp_apply_dft_to_dft( self.ptr, - c.as_mut_ptr() as *mut vec_znx_dft_t, - (c.size() * c.cols()) as u64, - a.as_ptr(), + res.as_mut_ptr() as *mut vec_znx_dft_t, + (res.size() * res.cols()) as u64, + a.as_ptr() as *const vec_znx_dft_t, (a.size() * a.cols()) as u64, - a.n() as u64, b.as_ptr() as *const vmp::vmp_pmat_t, (b.rows() * b.cols_in()) as u64, (b.size() * b.cols_out()) as u64, @@ -541,164 +301,131 @@ impl MatZnxDftOps<&mut [u8], &[u8], FFT64> for Module { ) } } - - fn vmp_apply_dft_to_dft( - &self, - c: &mut VecZnxDft<&mut [u8], FFT64>, - a: &VecZnxDft<&[u8], FFT64>, - b: &MatZnxDft<&[u8], FFT64>, - scratch: &mut Scratch, - ) { - { - #[cfg(debug_assertions)] - { - assert_eq!(c.n(), self.n()); - assert_eq!(b.n(), self.n()); - assert_eq!(a.n(), self.n()); - assert_eq!( - c.cols(), - b.cols_out(), - "c.cols(): {} != b.cols_out: {}", - c.cols(), - b.cols_out() - ); - assert_eq!( - a.cols(), - b.cols_in(), - "a.cols(): {} != b.cols_in: {}", - a.cols(), - b.cols_in() - ); - // assert!( - // tmp_bytes.len() - // >= self.vmp_apply_dft_to_dft_tmp_bytes( - // c.cols(), - // c.size(), - // a.cols(), - // a.size(), - // b.rows(), - // b.cols_in(), - // b.cols_out(), - // b.size() - // ) - // ); - // assert_alignement(tmp_bytes.as_ptr()); - } - - let (tmp_bytes, _) = scratch.tmp_scalar_slice(self.vmp_apply_dft_to_dft_tmp_bytes( - c.cols(), - c.size(), - a.cols(), - a.size(), - b.rows(), - b.cols_in(), - b.cols_out(), - b.size(), - )); - unsafe { - vmp::vmp_apply_dft_to_dft( - self.ptr, - c.as_mut_ptr() as *mut vec_znx_dft_t, - c.poly_count() as u64, - a.as_ptr() as *const vec_znx_dft_t, - a.poly_count() as u64, - b.as_ptr() as *const vmp::vmp_pmat_t, - b.rows() as u64, - (b.size() * b.cols()) as u64, - tmp_bytes.as_mut_ptr(), - ) - } - } - } } - #[cfg(test)] mod tests { - use crate::ScratchOwned; - use crate::mat_znx_dft_ops::*; - use crate::vec_znx_big_ops::*; - use crate::vec_znx_dft_ops::*; - use crate::vec_znx_ops::*; use crate::{ - FFT64, MatZnxDft, MatZnxDftOps, Module, Sampling, 
VecZnx, VecZnxBig, VecZnxBigOps, VecZnxDft, VecZnxDftOps, alloc_aligned, + Encoding, FFT64, MatZnxDft, MatZnxDftOps, Module, Sampling, ScratchOwned, VecZnx, VecZnxAlloc, VecZnxBig, VecZnxBigAlloc, + VecZnxBigOps, VecZnxBigScratch, VecZnxDft, VecZnxDftAlloc, VecZnxDftOps, ZnxInfos, ZnxView, ZnxViewMut, }; use sampling::source::Source; + use super::{MatZnxDftAlloc, MatZnxDftScratch}; + #[test] - fn vmp_prepare_row_dft() { + fn vmp_prepare_row() { let module: Module = Module::::new(16); let log_base2k: usize = 8; let mat_rows: usize = 4; let mat_cols_in: usize = 2; let mat_cols_out: usize = 2; let mat_size: usize = 5; - let mut a: VecZnx<_> = module.new_vec_znx(mat_cols_out, mat_size); - let mut b: VecZnx<_> = module.new_vec_znx(mat_cols_out, mat_size); - let mut a_dft: VecZnxDft<_, FFT64> = module.new_vec_znx_dft(mat_cols_out, mat_size); - let mut a_big: VecZnxBig<_, FFT64> = module.new_vec_znx_big(mat_cols_out, mat_size); - let mut b_dft: VecZnxDft<_, FFT64> = module.new_vec_znx_dft(mat_cols_out, mat_size); - let mut vmpmat_0: MatZnxDft<_, FFT64> = module.new_mat_znx_dft(mat_rows, mat_cols_in, mat_cols_out, mat_size); - let mut vmpmat_1: MatZnxDft<_, FFT64> = module.new_mat_znx_dft(mat_rows, mat_cols_in, mat_cols_out, mat_size); - - // let mut tmp_bytes: Vec = - // alloc_aligned(module.vmp_prepare_row_tmp_bytes(mat_cols_out, mat_size) | module.vec_znx_big_normalize_tmp_bytes()); - let mut scratch = ScratchOwned::new( - 2 * (module.vmp_prepare_row_tmp_bytes(mat_cols_out, mat_size) + module.vec_znx_big_normalize_tmp_bytes()), - ); - let mut tmp_bytes: Vec = - alloc_aligned::( as VecZnxDftOps, Vec, _>>::vec_znx_idft_tmp_bytes(&module)); + let mut a: VecZnx> = module.new_vec_znx(mat_cols_out, mat_size); + let mut a_dft: VecZnxDft, FFT64> = module.new_vec_znx_dft(mat_cols_out, mat_size); + let mut b_dft: VecZnxDft, FFT64> = module.new_vec_znx_dft(mat_cols_out, mat_size); + let mut mat: MatZnxDft, FFT64> = module.new_mat_znx_dft(mat_rows, mat_cols_in, mat_cols_out, mat_size); for col_in in 0..mat_cols_in { for row_i in 0..mat_rows { let mut source: Source = Source::new([0u8; 32]); - (0..mat_cols_out).for_each(|col_out| { module.fill_uniform(log_base2k, &mut a, col_out, mat_size, &mut source); module.vec_znx_dft(&mut a_dft, col_out, &a, col_out); }); - - module.vmp_prepare_row( - &mut vmpmat_0.to_mut(), - row_i, - col_in, - &a.to_ref(), - scratch.borrow(), - ); - - // Checks that prepare(mat_znx_dft, a) = prepare_dft(mat_znx_dft, a_dft) - module.vmp_prepare_row_dft(&mut vmpmat_1.to_mut(), row_i, col_in, &a_dft.to_ref()); - assert_eq!(vmpmat_0.raw(), vmpmat_1.raw()); - - // Checks that a_dft = extract_dft(prepare(mat_znx_dft, a), b_dft) - module.vmp_extract_row_dft(&mut b_dft.to_mut(), &vmpmat_0.to_ref(), row_i, col_in); + module.vmp_prepare_row(&mut mat, row_i, col_in, &a_dft); + module.vmp_extract_row(&mut b_dft, &mat, row_i, col_in); assert_eq!(a_dft.raw(), b_dft.raw()); - - // Checks that a_big = extract(prepare_dft(mat_znx_dft, a_dft), b_big) - module.vmp_extract_row( - log_base2k, - &mut b.to_mut(), - &vmpmat_0.to_ref(), - row_i, - col_in, - scratch.borrow(), - ); - - (0..mat_cols_out).for_each(|col_out| { - module.vec_znx_idft(&mut a_big, col_out, &a_dft, col_out, &mut tmp_bytes); - module.vec_znx_big_normalize( - log_base2k, - &mut a.to_mut(), - col_out, - &a_big.to_ref(), - col_out, - scratch.borrow(), - ); - }); - - assert_eq!(a.raw(), b.raw()); } } module.free(); } + + #[test] + fn vmp_apply() { + let log_n: i32 = 5; + let n: usize = 1 << log_n; + + let module: Module = Module::::new(n); 
+ let log_base2k: usize = 15; + let a_size: usize = 5; + let mat_size: usize = 6; + let res_size: usize = 5; + + [1, 2].iter().for_each(|in_cols| { + [1, 2].iter().for_each(|out_cols| { + let a_cols: usize = *in_cols; + let res_cols: usize = *out_cols; + + let mat_rows: usize = a_size; + let mat_cols_in: usize = a_cols; + let mat_cols_out: usize = res_cols; + let res_cols: usize = mat_cols_out; + + let mut scratch: ScratchOwned = ScratchOwned::new( + module.vmp_apply_tmp_bytes( + res_size, + a_size, + mat_rows, + mat_cols_in, + mat_cols_out, + mat_size, + ) | module.vec_znx_big_normalize_tmp_bytes(), + ); + + let mut a: VecZnx> = module.new_vec_znx(a_cols, a_size); + + (0..a_cols).for_each(|i| { + a.at_mut(i, 2)[i + 1] = 1; + }); + + let mut mat_znx_dft: MatZnxDft, FFT64> = + module.new_mat_znx_dft(mat_rows, mat_cols_in, mat_cols_out, mat_size); + + let mut c_dft: VecZnxDft, FFT64> = module.new_vec_znx_dft(mat_cols_out, mat_size); + let mut c_big: VecZnxBig, FFT64> = module.new_vec_znx_big(mat_cols_out, mat_size); + + let mut tmp: VecZnx> = module.new_vec_znx(mat_cols_out, mat_size); + + // Construts a [VecZnxMatDft] that performs cyclic rotations on each submatrix. + (0..a.size()).for_each(|row_i| { + (0..mat_cols_in).for_each(|col_in_i| { + (0..mat_cols_out).for_each(|col_out_i| { + let idx = 1 + col_in_i * mat_cols_out + col_out_i; + tmp.at_mut(col_out_i, row_i)[idx] = 1 as i64; // X^{idx} + module.vec_znx_dft(&mut c_dft, col_out_i, &tmp, col_out_i); + tmp.at_mut(col_out_i, row_i)[idx] = 0 as i64; + }); + module.vmp_prepare_row(&mut mat_znx_dft, row_i, col_in_i, &c_dft); + }); + }); + + let mut a_dft: VecZnxDft, FFT64> = module.new_vec_znx_dft(a_cols, a_size); + (0..a_cols).for_each(|i| { + module.vec_znx_dft(&mut a_dft, i, &a, i); + }); + + module.vmp_apply(&mut c_dft, &a_dft, &mat_znx_dft, scratch.borrow()); + + let mut res_have_vi64: Vec = vec![i64::default(); n]; + + let mut res_have: VecZnx> = module.new_vec_znx(res_cols, res_size); + (0..mat_cols_out).for_each(|i| { + module.vec_znx_idft_tmp_a(&mut c_big, i, &mut c_dft, i); + module.vec_znx_big_normalize(log_base2k, &mut res_have, i, &c_big, i, scratch.borrow()); + }); + + (0..mat_cols_out).for_each(|col_i| { + let mut res_want_vi64: Vec = vec![i64::default(); n]; + (0..a_cols).for_each(|i| { + res_want_vi64[(i + 1) + (1 + i * mat_cols_out + col_i)] = 1; + }); + res_have.decode_vec_i64(col_i, log_base2k, log_base2k * 3, &mut res_have_vi64); + assert_eq!(res_have_vi64, res_want_vi64); + }); + }); + }); + + module.free(); + } } diff --git a/base2k/src/sampling.rs b/base2k/src/sampling.rs index a8b1962..b254286 100644 --- a/base2k/src/sampling.rs +++ b/base2k/src/sampling.rs @@ -1,53 +1,47 @@ use crate::znx_base::ZnxViewMut; -use crate::{Backend, Module, VecZnx}; +use crate::{Backend, Module, VecZnx, VecZnxToMut}; use rand_distr::{Distribution, Normal}; use sampling::source::Source; pub trait Sampling { /// Fills the first `size` size with uniform values in \[-2^{log_base2k-1}, 2^{log_base2k-1}\] - fn fill_uniform + AsRef<[u8]>>( - &self, - log_base2k: usize, - a: &mut VecZnx, - col_i: usize, - size: usize, - source: &mut Source, - ); + fn fill_uniform(&self, log_base2k: usize, a: &mut A, col_i: usize, size: usize, source: &mut Source) + where + A: VecZnxToMut; /// Adds vector sampled according to the provided distribution, scaled by 2^{-log_k} and bounded to \[-bound, bound\]. 
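// A short sketch of the generic sampling entry points after this change, assuming
// `fill_uniform` / `add_normal` keep the argument order shown below and that `ct` has at
// least two columns as in the rlwe_encrypt example above; sigma = 3.2 and bound = 19.2 are
// illustrative values, not library defaults.
fn noise_sketch(module: &Module<FFT64>, ct: &mut VecZnx<Vec<u8>>, source: &mut Source) {
    let (log_base2k, log_k, size) = (17usize, 51usize, 3usize);
    // ct[1] <- uniform over [-2^{log_base2k-1}, 2^{log_base2k-1}] on the first `size` limbs.
    module.fill_uniform(log_base2k, ct, 1, size, source);
    // ct[0] <- ct[0] + e, with e a discrete Gaussian of standard deviation 3.2,
    // truncated to [-19.2, 19.2] and placed at scale 2^{-log_k}.
    module.add_normal(log_base2k, ct, 0, log_k, source, 3.2, 19.2);
}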
- fn add_dist_f64 + AsRef<[u8]>, D: Distribution>( + fn add_dist_f64>( &self, log_base2k: usize, - a: &mut VecZnx, + a: &mut A, col_i: usize, log_k: usize, source: &mut Source, dist: D, bound: f64, - ); + ) where + A: VecZnxToMut; /// Adds a discrete normal vector scaled by 2^{-log_k} with the provided standard deviation and bounded to \[-bound, bound\]. - fn add_normal + AsRef<[u8]>>( + fn add_normal( &self, log_base2k: usize, - a: &mut VecZnx, + a: &mut A, col_i: usize, log_k: usize, source: &mut Source, sigma: f64, bound: f64, - ); + ) where + A: VecZnxToMut; } impl Sampling for Module { - fn fill_uniform + AsRef<[u8]>>( - &self, - log_base2k: usize, - a: &mut VecZnx, - col_i: usize, - size: usize, - source: &mut Source, - ) { + fn fill_uniform(&self, log_base2k: usize, a: &mut A, col_i: usize, size: usize, source: &mut Source) + where + A: VecZnxToMut, + { + let mut a: VecZnx<&mut [u8]> = a.to_mut(); let base2k: u64 = 1 << log_base2k; let mask: u64 = base2k - 1; let base2k_half: i64 = (base2k >> 1) as i64; @@ -58,16 +52,19 @@ impl Sampling for Module { }) } - fn add_dist_f64 + AsRef<[u8]>, D: Distribution>( + fn add_dist_f64>( &self, log_base2k: usize, - a: &mut VecZnx, + a: &mut A, col_i: usize, log_k: usize, source: &mut Source, dist: D, bound: f64, - ) { + ) where + A: VecZnxToMut, + { + let mut a: VecZnx<&mut [u8]> = a.to_mut(); assert!( (bound.log2().ceil() as i64) < 64, "invalid bound: ceil(log2(bound))={} > 63", @@ -96,16 +93,10 @@ impl Sampling for Module { } } - fn add_normal + AsRef<[u8]>>( - &self, - log_base2k: usize, - a: &mut VecZnx, - col_i: usize, - log_k: usize, - source: &mut Source, - sigma: f64, - bound: f64, - ) { + fn add_normal(&self, log_base2k: usize, a: &mut A, col_i: usize, log_k: usize, source: &mut Source, sigma: f64, bound: f64) + where + A: VecZnxToMut, + { self.add_dist_f64( log_base2k, a, diff --git a/base2k/src/scalar_znx.rs b/base2k/src/scalar_znx.rs index c5052eb..acdac8c 100644 --- a/base2k/src/scalar_znx.rs +++ b/base2k/src/scalar_znx.rs @@ -1,13 +1,10 @@ use crate::znx_base::ZnxInfos; -use crate::{Backend, DataView, DataViewMut, Module, ZnxView, ZnxViewMut, alloc_aligned}; +use crate::{Backend, DataView, DataViewMut, Module, ZnxSliceSize, ZnxView, ZnxViewMut, alloc_aligned}; use rand::seq::SliceRandom; use rand_core::RngCore; use rand_distr::{Distribution, weighted::WeightedIndex}; use sampling::source::Source; -// pub const SCALAR_ZNX_ROWS: usize = 1; -// pub const SCALAR_ZNX_SIZE: usize = 1; - pub struct Scalar { data: D, n: usize, @@ -30,7 +27,9 @@ impl ZnxInfos for Scalar { fn size(&self) -> usize { 1 } +} +impl ZnxSliceSize for Scalar { fn sl(&self) -> usize { self.n() } @@ -70,19 +69,6 @@ impl + AsRef<[u8]>> Scalar { .for_each(|x: &mut i64| *x = (((source.next_u32() & 1) as i64) << 1) - 1); self.at_mut(col, 0).shuffle(source); } - - // pub fn alias_as_vec_znx(&self) -> VecZnx { - // VecZnx { - // inner: ZnxBase { - // n: self.n(), - // rows: 1, - // cols: 1, - // size: 1, - // data: Vec::new(), - // ptr: self.ptr() as *mut u8, - // }, - // } - // } } impl>> Scalar { @@ -116,7 +102,6 @@ pub trait ScalarAlloc { fn bytes_of_scalar(&self, cols: usize) -> usize; fn new_scalar(&self, cols: usize) -> ScalarOwned; fn new_scalar_from_bytes(&self, cols: usize, bytes: Vec) -> ScalarOwned; - // fn new_scalar_from_bytes_borrow(&self, cols: usize, bytes: &mut [u8]) -> Scalar; } impl ScalarAlloc for Module { @@ -129,31 +114,62 @@ impl ScalarAlloc for Module { fn new_scalar_from_bytes(&self, cols: usize, bytes: Vec) -> ScalarOwned { 
ScalarOwned::new_from_bytes::(self.n(), cols, bytes) } - // fn new_scalar_from_bytes_borrow(&self, cols: usize, bytes: &mut [u8]) -> Scalar { - // Scalar::from_bytes_borrow(self, SCALAR_ZNX_ROWS, cols, SCALAR_ZNX_SIZE, bytes) - // } } -// impl ZnxAlloc for Scalar { -// type Scalar = i64; +pub trait ScalarToRef { + fn to_ref(&self) -> Scalar<&[u8]>; +} -// fn from_bytes_borrow(module: &Module, _rows: usize, cols: usize, _size: usize, bytes: &mut [u8]) -> Self { -// Self { -// inner: ZnxBase::from_bytes_borrow(module.n(), SCALAR_ZNX_ROWS, cols, SCALAR_ZNX_SIZE, bytes), -// } -// } +pub trait ScalarToMut { + fn to_mut(&mut self) -> Scalar<&mut [u8]>; +} -// fn bytes_of(module: &Module, _rows: usize, cols: usize, _size: usize) -> usize { -// debug_assert_eq!( -// _rows, SCALAR_ZNX_ROWS, -// "rows != {} not supported for Scalar", -// SCALAR_ZNX_ROWS -// ); -// debug_assert_eq!( -// _size, SCALAR_ZNX_SIZE, -// "rows != {} not supported for Scalar", -// SCALAR_ZNX_SIZE -// ); -// module.n() * cols * std::mem::size_of::() -// } -// } +impl ScalarToMut for Scalar> { + fn to_mut(&mut self) -> Scalar<&mut [u8]> { + Scalar { + data: self.data.as_mut_slice(), + n: self.n, + cols: self.cols, + } + } +} + +impl ScalarToRef for Scalar> { + fn to_ref(&self) -> Scalar<&[u8]> { + Scalar { + data: self.data.as_slice(), + n: self.n, + cols: self.cols, + } + } +} + +impl ScalarToMut for Scalar<&mut [u8]> { + fn to_mut(&mut self) -> Scalar<&mut [u8]> { + Scalar { + data: self.data, + n: self.n, + cols: self.cols, + } + } +} + +impl ScalarToRef for Scalar<&mut [u8]> { + fn to_ref(&self) -> Scalar<&[u8]> { + Scalar { + data: self.data, + n: self.n, + cols: self.cols, + } + } +} + +impl ScalarToRef for Scalar<&[u8]> { + fn to_ref(&self) -> Scalar<&[u8]> { + Scalar { + data: self.data, + n: self.n, + cols: self.cols, + } + } +} diff --git a/base2k/src/scalar_znx_dft.rs b/base2k/src/scalar_znx_dft.rs index 09b26d4..c93609f 100644 --- a/base2k/src/scalar_znx_dft.rs +++ b/base2k/src/scalar_znx_dft.rs @@ -2,19 +2,16 @@ use std::marker::PhantomData; use crate::ffi::svp; use crate::znx_base::ZnxInfos; -use crate::{Backend, DataView, DataViewMut, FFT64, Module, ZnxView, alloc_aligned}; +use crate::{Backend, DataView, DataViewMut, FFT64, Module, ZnxSliceSize, ZnxView, alloc_aligned}; -pub const SCALAR_ZNX_DFT_ROWS: usize = 1; -pub const SCALAR_ZNX_DFT_SIZE: usize = 1; - -pub struct ScalarZnxDft { +pub struct ScalarZnxDft { data: D, n: usize, cols: usize, _phantom: PhantomData, } -impl ZnxInfos for ScalarZnxDft { +impl ZnxInfos for ScalarZnxDft { fn cols(&self) -> usize { self.cols } @@ -30,20 +27,22 @@ impl ZnxInfos for ScalarZnxDft { fn size(&self) -> usize { 1 } +} +impl ZnxSliceSize for ScalarZnxDft { fn sl(&self) -> usize { self.n() } } -impl DataView for ScalarZnxDft { +impl DataView for ScalarZnxDft { type D = D; fn data(&self) -> &Self::D { &self.data } } -impl DataViewMut for ScalarZnxDft { +impl DataViewMut for ScalarZnxDft { fn data_mut(&mut self) -> &mut Self::D { &mut self.data } @@ -78,20 +77,69 @@ impl>, B: Backend> ScalarZnxDft { _phantom: PhantomData, } } - - // fn from_bytes_borrow(module: &Module, _rows: usize, cols: usize, _size: usize, bytes: &mut [u8]) -> Self { - // debug_assert_eq!(bytes.len(), Self::bytes_of(module, _rows, cols, _size)); - // Self { - // inner: ZnxBase::from_bytes_borrow( - // module.n(), - // SCALAR_ZNX_DFT_ROWS, - // cols, - // SCALAR_ZNX_DFT_SIZE, - // bytes, - // ), - // _phantom: PhantomData, - // } - // } } pub type ScalarZnxDftOwned = ScalarZnxDft, B>; + +pub trait 
ScalarZnxDftToRef { + fn to_ref(&self) -> ScalarZnxDft<&[u8], B>; +} + +pub trait ScalarZnxDftToMut { + fn to_mut(&mut self) -> ScalarZnxDft<&mut [u8], B>; +} + +impl ScalarZnxDftToMut for ScalarZnxDft, B> { + fn to_mut(&mut self) -> ScalarZnxDft<&mut [u8], B> { + ScalarZnxDft { + data: self.data.as_mut_slice(), + n: self.n, + cols: self.cols, + _phantom: PhantomData, + } + } +} + +impl ScalarZnxDftToRef for ScalarZnxDft, B> { + fn to_ref(&self) -> ScalarZnxDft<&[u8], B> { + ScalarZnxDft { + data: self.data.as_slice(), + n: self.n, + cols: self.cols, + _phantom: PhantomData, + } + } +} + +impl ScalarZnxDftToMut for ScalarZnxDft<&mut [u8], B> { + fn to_mut(&mut self) -> ScalarZnxDft<&mut [u8], B> { + ScalarZnxDft { + data: self.data, + n: self.n, + cols: self.cols, + _phantom: PhantomData, + } + } +} + +impl ScalarZnxDftToRef for ScalarZnxDft<&mut [u8], B> { + fn to_ref(&self) -> ScalarZnxDft<&[u8], B> { + ScalarZnxDft { + data: self.data, + n: self.n, + cols: self.cols, + _phantom: PhantomData, + } + } +} + +impl ScalarZnxDftToRef for ScalarZnxDft<&[u8], B> { + fn to_ref(&self) -> ScalarZnxDft<&[u8], B> { + ScalarZnxDft { + data: self.data, + n: self.n, + cols: self.cols, + _phantom: PhantomData, + } + } +} diff --git a/base2k/src/scalar_znx_dft_ops.rs b/base2k/src/scalar_znx_dft_ops.rs index fc56e4e..ea98a57 100644 --- a/base2k/src/scalar_znx_dft_ops.rs +++ b/base2k/src/scalar_znx_dft_ops.rs @@ -1,26 +1,28 @@ -use crate::ffi::svp::{self, svp_ppol_t}; +use crate::ffi::svp; use crate::ffi::vec_znx_dft::vec_znx_dft_t; use crate::znx_base::{ZnxInfos, ZnxView, ZnxViewMut}; -use crate::{Backend, FFT64, Module, Scalar, ScalarZnxDft, ScalarZnxDftOwned, VecZnx, VecZnxDft}; +use crate::{ + Backend, FFT64, Module, ScalarToRef, ScalarZnxDft, ScalarZnxDftOwned, ScalarZnxDftToMut, ScalarZnxDftToRef, VecZnx, + VecZnxDft, VecZnxDftToMut, VecZnxToRef, ZnxSliceSize, +}; -pub trait ScalarZnxDftAlloc { +pub trait ScalarZnxDftAlloc { fn new_scalar_znx_dft(&self, cols: usize) -> ScalarZnxDftOwned; fn bytes_of_scalar_znx_dft(&self, cols: usize) -> usize; fn new_scalar_znx_dft_from_bytes(&self, cols: usize, bytes: Vec) -> ScalarZnxDftOwned; // fn new_scalar_znx_dft_from_bytes_borrow(&self, cols: usize, bytes: &mut [u8]) -> ScalarZnxDft; } -pub trait ScalarZnxDftOps { - fn svp_prepare(&self, res: &mut ScalarZnxDft, res_col: usize, a: &Scalar, a_col: usize); - fn svp_apply_dft( - &self, - res: &mut VecZnxDft, - res_col: usize, - a: &ScalarZnxDft, - a_col: usize, - b: &VecZnx, - b_col: usize, - ); +pub trait ScalarZnxDftOps { + fn svp_prepare(&self, res: &mut R, res_col: usize, a: &A, a_col: usize) + where + R: ScalarZnxDftToMut, + A: ScalarToRef; + fn svp_apply(&self, res: &mut R, res_col: usize, a: &A, a_col: usize, b: &B, b_col: usize) + where + R: VecZnxDftToMut, + A: ScalarZnxDftToRef, + B: VecZnxToRef; } impl ScalarZnxDftAlloc for Module { @@ -35,42 +37,38 @@ impl ScalarZnxDftAlloc for Module { fn new_scalar_znx_dft_from_bytes(&self, cols: usize, bytes: Vec) -> ScalarZnxDftOwned { ScalarZnxDftOwned::new_from_bytes(self, cols, bytes) } - - // fn new_scalar_znx_dft_from_bytes_borrow(&self, cols: usize, bytes: &mut [u8]) -> ScalarZnxDft { - // ScalarZnxDft::from_bytes_borrow(self, SCALAR_ZNX_DFT_ROWS, cols, SCALAR_ZNX_DFT_SIZE, bytes) - // } } -impl ScalarZnxDftOps for Module -where - DataMut: AsMut<[u8]> + AsRef<[u8]>, - Data: AsRef<[u8]>, -{ - fn svp_prepare(&self, res: &mut ScalarZnxDft, res_col: usize, a: &Scalar, a_col: usize) { +impl ScalarZnxDftOps for Module { + fn svp_prepare(&self, res: &mut R, 
res_col: usize, a: &A, a_col: usize) + where + R: ScalarZnxDftToMut, + A: ScalarToRef, + { unsafe { svp::svp_prepare( self.ptr, - res.at_mut_ptr(res_col, 0) as *mut svp_ppol_t, - a.at_ptr(a_col, 0), + res.to_mut().at_mut_ptr(res_col, 0) as *mut svp::svp_ppol_t, + a.to_ref().at_ptr(a_col, 0), ) } } - fn svp_apply_dft( - &self, - res: &mut VecZnxDft, - res_col: usize, - a: &ScalarZnxDft, - a_col: usize, - b: &VecZnx, - b_col: usize, - ) { + fn svp_apply(&self, res: &mut R, res_col: usize, a: &A, a_col: usize, b: &B, b_col: usize) + where + R: VecZnxDftToMut, + A: ScalarZnxDftToRef, + B: VecZnxToRef, + { + let mut res: VecZnxDft<&mut [u8], FFT64> = res.to_mut(); + let a: ScalarZnxDft<&[u8], FFT64> = a.to_ref(); + let b: VecZnx<&[u8]> = b.to_ref(); unsafe { svp::svp_apply_dft( self.ptr, res.at_mut_ptr(res_col, 0) as *mut vec_znx_dft_t, res.size() as u64, - a.at_ptr(a_col, 0) as *const svp_ppol_t, + a.at_ptr(a_col, 0) as *const svp::svp_ppol_t, b.at_ptr(b_col, 0), b.size() as u64, b.sl() as u64, diff --git a/base2k/src/vec_znx.rs b/base2k/src/vec_znx.rs index 09b0051..70d8fb3 100644 --- a/base2k/src/vec_znx.rs +++ b/base2k/src/vec_znx.rs @@ -1,14 +1,13 @@ use crate::DataView; use crate::DataViewMut; +use crate::ZnxSliceSize; use crate::alloc_aligned; use crate::assert_alignement; use crate::cast_mut; use crate::ffi::znx; -use crate::znx_base::{ZnxInfos, ZnxView, ZnxViewMut, switch_degree}; +use crate::znx_base::{ZnxInfos, ZnxView, ZnxViewMut}; use std::{cmp::min, fmt}; -// pub const VEC_ZNX_ROWS: usize = 1; - /// [VecZnx] represents collection of contiguously stacked vector of small norm polynomials of /// Zn\[X\] with [i64] coefficients. /// A [VecZnx] is composed of multiple Zn\[X\] polynomials stored in a single contiguous array @@ -20,7 +19,7 @@ use std::{cmp::min, fmt}; /// layout is: `[a0, b0, c0, a1, b1, c1, a2, b2, c2, a3, b3, c3]`, where ai, bi, ci /// are small polynomials of Zn\[X\]. pub struct VecZnx { - data: D, + pub data: D, n: usize, cols: usize, size: usize, @@ -42,9 +41,11 @@ impl ZnxInfos for VecZnx { fn size(&self) -> usize { self.size } +} +impl ZnxSliceSize for VecZnx { fn sl(&self) -> usize { - self.cols() * self.n() + self.n() * self.cols() } } @@ -66,10 +67,6 @@ impl> ZnxView for VecZnx { } impl + AsRef<[u8]>> VecZnx { - pub fn normalize(&mut self, log_base2k: usize, col: usize, carry: &mut [u8]) { - normalize(log_base2k, self, col, carry) - } - /// Truncates the precision of the [VecZnx] by k bits. /// /// # Arguments @@ -92,11 +89,6 @@ impl + AsRef<[u8]>> VecZnx { .for_each(|x: &mut i64| *x &= mask) } } - - /// Switches degree of from `a.n()` to `self.n()` into `self` - pub fn switch_degree>(&mut self, col: usize, a: &VecZnx, col_a: usize) { - switch_degree(self, col_a, a, col) - } } impl>> VecZnx { @@ -126,6 +118,17 @@ impl>> VecZnx { } } +impl VecZnx { + pub(crate) fn from_data(data: D, n: usize, cols: usize, size: usize) -> Self { + Self { + data, + n, + cols, + size, + } + } +} + /// Copies the coefficients of `a` on the receiver. /// Copy is done with the minimum size matching both backing arrays. /// Panics if the cols do not match. 
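// The rework threads `VecZnxToRef` / `VecZnxToMut` (and their Big/Dft/Mat counterparts)
// through the public traits, so a single call site accepts owned `Vec<u8>` as well as
// borrowed `&[u8]` / `&mut [u8]` backings. A minimal caller-side sketch written against
// those bounds, assuming `vec_znx_normalize_inplace` is generic over `VecZnxToMut` as the
// rlwe_encrypt example above suggests:
fn normalize_all<A>(module: &Module<FFT64>, log_base2k: usize, a: &mut A, scratch: &mut Scratch)
where
    A: VecZnxToMut,
{
    // Works identically whether `a` is a VecZnx<Vec<u8>> or a VecZnx<&mut [u8]>.
    let cols = a.to_mut().cols();
    for col in 0..cols {
        module.vec_znx_normalize_inplace(log_base2k, a, col, scratch);
    }
}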
@@ -141,10 +144,12 @@ where data_b[..size].copy_from_slice(&data_a[..size]) } +#[allow(dead_code)] fn normalize_tmp_bytes(n: usize) -> usize { n * std::mem::size_of::() } +#[allow(dead_code)] fn normalize + AsRef<[u8]>>(log_base2k: usize, a: &mut VecZnx, a_col: usize, tmp_bytes: &mut [u8]) { let n: usize = a.n(); @@ -216,8 +221,16 @@ pub type VecZnxOwned = VecZnx>; pub type VecZnxMut<'a> = VecZnx<&'a mut [u8]>; pub type VecZnxRef<'a> = VecZnx<&'a [u8]>; -impl VecZnx> { - pub fn to_mut(&mut self) -> VecZnx<&mut [u8]> { +pub trait VecZnxToRef { + fn to_ref(&self) -> VecZnx<&[u8]>; +} + +pub trait VecZnxToMut { + fn to_mut(&mut self) -> VecZnx<&mut [u8]>; +} + +impl VecZnxToMut for VecZnx> { + fn to_mut(&mut self) -> VecZnx<&mut [u8]> { VecZnx { data: self.data.as_mut_slice(), n: self.n, @@ -225,8 +238,10 @@ impl VecZnx> { size: self.size, } } +} - pub fn to_ref(&self) -> VecZnx<&[u8]> { +impl VecZnxToRef for VecZnx> { + fn to_ref(&self) -> VecZnx<&[u8]> { VecZnx { data: self.data.as_slice(), n: self.n, @@ -236,10 +251,32 @@ impl VecZnx> { } } -impl VecZnx<&mut [u8]> { - pub fn to_ref(&self) -> VecZnx<&[u8]> { +impl VecZnxToMut for VecZnx<&mut [u8]> { + fn to_mut(&mut self) -> VecZnx<&mut [u8]> { VecZnx { - data: &self.data, + data: self.data, + n: self.n, + cols: self.cols, + size: self.size, + } + } +} + +impl VecZnxToRef for VecZnx<&mut [u8]> { + fn to_ref(&self) -> VecZnx<&[u8]> { + VecZnx { + data: self.data, + n: self.n, + cols: self.cols, + size: self.size, + } + } +} + +impl VecZnxToRef for VecZnx<&[u8]> { + fn to_ref(&self) -> VecZnx<&[u8]> { + VecZnx { + data: self.data, n: self.n, cols: self.cols, size: self.size, diff --git a/base2k/src/vec_znx_big.rs b/base2k/src/vec_znx_big.rs index fe67516..8f70272 100644 --- a/base2k/src/vec_znx_big.rs +++ b/base2k/src/vec_znx_big.rs @@ -1,12 +1,9 @@ use crate::ffi::vec_znx_big; use crate::znx_base::{ZnxInfos, ZnxView}; -use crate::{Backend, DataView, DataViewMut, FFT64, Module, alloc_aligned}; +use crate::{Backend, DataView, DataViewMut, FFT64, Module, ZnxSliceSize, alloc_aligned}; use std::marker::PhantomData; -// const VEC_ZNX_BIG_ROWS: usize = 1; - -/// VecZnxBig is `Backend` dependent, denoted with backend generic `B` -pub struct VecZnxBig { +pub struct VecZnxBig { data: D, n: usize, cols: usize, @@ -14,7 +11,7 @@ pub struct VecZnxBig { _phantom: PhantomData, } -impl ZnxInfos for VecZnxBig { +impl ZnxInfos for VecZnxBig { fn cols(&self) -> usize { self.cols } @@ -30,20 +27,22 @@ impl ZnxInfos for VecZnxBig { fn size(&self) -> usize { self.size } +} +impl ZnxSliceSize for VecZnxBig { fn sl(&self) -> usize { - self.cols() * self.n() + self.n() * self.cols() } } -impl DataView for VecZnxBig { +impl DataView for VecZnxBig { type D = D; fn data(&self) -> &Self::D { &self.data } } -impl DataViewMut for VecZnxBig { +impl DataViewMut for VecZnxBig { fn data_mut(&mut self) -> &mut Self::D { &mut self.data } @@ -82,7 +81,7 @@ impl>, B: Backend> VecZnxBig { } } -impl VecZnxBig { +impl VecZnxBig { pub(crate) fn from_data(data: D, n: usize, cols: usize, size: usize) -> Self { Self { data, @@ -96,8 +95,16 @@ impl VecZnxBig { pub type VecZnxBigOwned = VecZnxBig, B>; -impl VecZnxBig, B> { - pub fn to_mut(&mut self) -> VecZnxBig<&mut [u8], B> { +pub trait VecZnxBigToRef { + fn to_ref(&self) -> VecZnxBig<&[u8], B>; +} + +pub trait VecZnxBigToMut { + fn to_mut(&mut self) -> VecZnxBig<&mut [u8], B>; +} + +impl VecZnxBigToMut for VecZnxBig, B> { + fn to_mut(&mut self) -> VecZnxBig<&mut [u8], B> { VecZnxBig { data: self.data.as_mut_slice(), n: self.n, @@ 
-106,8 +113,10 @@ impl VecZnxBig, B> { _phantom: PhantomData, } } +} - pub fn to_ref(&self) -> VecZnxBig<&[u8], B> { +impl VecZnxBigToRef for VecZnxBig, B> { + fn to_ref(&self) -> VecZnxBig<&[u8], B> { VecZnxBig { data: self.data.as_slice(), n: self.n, @@ -117,3 +126,39 @@ impl VecZnxBig, B> { } } } + +impl VecZnxBigToMut for VecZnxBig<&mut [u8], B> { + fn to_mut(&mut self) -> VecZnxBig<&mut [u8], B> { + VecZnxBig { + data: self.data, + n: self.n, + cols: self.cols, + size: self.size, + _phantom: PhantomData, + } + } +} + +impl VecZnxBigToRef for VecZnxBig<&mut [u8], B> { + fn to_ref(&self) -> VecZnxBig<&[u8], B> { + VecZnxBig { + data: self.data, + n: self.n, + cols: self.cols, + size: self.size, + _phantom: PhantomData, + } + } +} + +impl VecZnxBigToRef for VecZnxBig<&[u8], B> { + fn to_ref(&self) -> VecZnxBig<&[u8], B> { + VecZnxBig { + data: self.data, + n: self.n, + cols: self.cols, + size: self.size, + _phantom: PhantomData, + } + } +} diff --git a/base2k/src/vec_znx_big_ops.rs b/base2k/src/vec_znx_big_ops.rs index d0e4bd3..185a20c 100644 --- a/base2k/src/vec_znx_big_ops.rs +++ b/base2k/src/vec_znx_big_ops.rs @@ -1,8 +1,11 @@ use crate::ffi::vec_znx; use crate::znx_base::{ZnxInfos, ZnxView, ZnxViewMut}; -use crate::{Backend, FFT64, Module, Scratch, VecZnx, VecZnxBig, VecZnxBigOwned, VecZnxScratch, bytes_of_vec_znx_big}; +use crate::{ + Backend, FFT64, Module, Scratch, VecZnx, VecZnxBig, VecZnxBigOwned, VecZnxBigToMut, VecZnxBigToRef, VecZnxScratch, + VecZnxToMut, VecZnxToRef, ZnxSliceSize, bytes_of_vec_znx_big, +}; -pub trait VecZnxBigAlloc { +pub trait VecZnxBigAlloc { /// Allocates a vector Z[X]/(X^N+1) that stores not normalized values. fn new_vec_znx_big(&self, cols: usize, size: usize) -> VecZnxBigOwned; @@ -39,79 +42,77 @@ pub trait VecZnxBigAlloc { fn bytes_of_vec_znx_big(&self, cols: usize, size: usize) -> usize; } -pub trait VecZnxBigOps { +pub trait VecZnxBigOps { /// Adds `a` to `b` and stores the result on `c`. - fn vec_znx_big_add( - &self, - res: &mut VecZnxBig, - res_col: usize, - a: &VecZnxBig, - a_col: usize, - b: &VecZnxBig, - b_col: usize, - ); + fn vec_znx_big_add(&self, res: &mut R, res_col: usize, a: &A, a_col: usize, b: &B, b_col: usize) + where + R: VecZnxBigToMut, + A: VecZnxBigToRef, + B: VecZnxBigToRef; /// Adds `a` to `b` and stores the result on `b`. - fn vec_znx_big_add_inplace(&self, res: &mut VecZnxBig, res_col: usize, a: &VecZnxBig, a_col: usize); + fn vec_znx_big_add_inplace(&self, res: &mut R, res_col: usize, a: &A, a_col: usize) + where + R: VecZnxBigToMut, + A: VecZnxBigToRef; /// Adds `a` to `b` and stores the result on `c`. - fn vec_znx_big_add_small( - &self, - res: &mut VecZnxBig, - res_col: usize, - a: &VecZnxBig, - a_col: usize, - b: &VecZnx, - b_col: usize, - ); + fn vec_znx_big_add_small(&self, res: &mut R, res_col: usize, a: &A, a_col: usize, b: &B, b_col: usize) + where + R: VecZnxBigToMut, + A: VecZnxBigToRef, + B: VecZnxToRef; /// Adds `a` to `b` and stores the result on `b`. - fn vec_znx_big_add_small_inplace(&self, res: &mut VecZnxBig, res_col: usize, a: &VecZnx, a_col: usize); + fn vec_znx_big_add_small_inplace(&self, res: &mut R, res_col: usize, a: &A, a_col: usize) + where + R: VecZnxBigToMut, + A: VecZnxToRef; /// Subtracts `a` to `b` and stores the result on `c`. 
- fn vec_znx_big_sub( - &self, - res: &mut VecZnxBig, - res_col: usize, - a: &VecZnxBig, - a_col: usize, - b: &VecZnxBig, - b_col: usize, - ); + fn vec_znx_big_sub(&self, res: &mut R, res_col: usize, a: &A, a_col: usize, b: &B, b_col: usize) + where + R: VecZnxBigToMut, + A: VecZnxBigToRef, + B: VecZnxBigToRef; /// Subtracts `a` from `b` and stores the result on `b`. - fn vec_znx_big_sub_ab_inplace(&self, res: &mut VecZnxBig, res_col: usize, a: &VecZnxBig, a_col: usize); + fn vec_znx_big_sub_ab_inplace(&self, res: &mut R, res_col: usize, a: &A, a_col: usize) + where + R: VecZnxBigToMut, + A: VecZnxBigToRef; /// Subtracts `b` from `a` and stores the result on `b`. - fn vec_znx_big_sub_ba_inplace(&self, res: &mut VecZnxBig, res_col: usize, a: &VecZnxBig, a_col: usize); + fn vec_znx_big_sub_ba_inplace(&self, res: &mut R, res_col: usize, a: &A, a_col: usize) + where + R: VecZnxBigToMut, + A: VecZnxBigToRef; /// Subtracts `b` from `a` and stores the result on `c`. - fn vec_znx_big_sub_small_a( - &self, - res: &mut VecZnxBig, - res_col: usize, - a: &VecZnx, - a_col: usize, - b: &VecZnxBig, - b_col: usize, - ); + fn vec_znx_big_sub_small_a(&self, res: &mut R, res_col: usize, a: &A, a_col: usize, b: &B, b_col: usize) + where + R: VecZnxBigToMut, + A: VecZnxToRef, + B: VecZnxBigToRef; /// Subtracts `a` from `res` and stores the result on `res`. - fn vec_znx_big_sub_small_a_inplace(&self, res: &mut VecZnxBig, res_col: usize, a: &VecZnx, a_col: usize); + fn vec_znx_big_sub_small_a_inplace(&self, res: &mut R, res_col: usize, a: &A, a_col: usize) + where + R: VecZnxBigToMut, + A: VecZnxToRef; /// Subtracts `b` from `a` and stores the result on `c`. - fn vec_znx_big_sub_small_b( - &self, - res: &mut VecZnxBig, - res_col: usize, - a: &VecZnxBig, - a_col: usize, - b: &VecZnx, - b_col: usize, - ); + fn vec_znx_big_sub_small_b(&self, res: &mut R, res_col: usize, a: &A, a_col: usize, b: &B, b_col: usize) + where + R: VecZnxBigToMut, + A: VecZnxBigToRef, + B: VecZnxToRef; /// Subtracts `res` from `a` and stores the result on `res`. - fn vec_znx_big_sub_small_b_inplace(&self, res: &mut VecZnxBig, res_col: usize, a: &VecZnx, a_col: usize); + fn vec_znx_big_sub_small_b_inplace(&self, res: &mut R, res_col: usize, a: &A, a_col: usize) + where + R: VecZnxBigToMut, + A: VecZnxToRef; /// Normalizes `a` and stores the result on `b`. /// @@ -119,28 +120,28 @@ pub trait VecZnxBigOps { /// /// * `log_base2k`: normalization basis. /// * `tmp_bytes`: scratch space of size at least [VecZnxBigOps::vec_znx_big_normalize]. - fn vec_znx_big_normalize( + fn vec_znx_big_normalize( &self, log_base2k: usize, - res: &mut VecZnx, + res: &mut R, res_col: usize, - a: &VecZnxBig, + a: &A, a_col: usize, scratch: &mut Scratch, - ); + ) where + R: VecZnxToMut, + A: VecZnxBigToRef; /// Applies the automorphism X^i -> X^ik on `a` and stores the result on `b`. - fn vec_znx_big_automorphism( - &self, - k: i64, - res: &mut VecZnxBig, - res_col: usize, - a: &VecZnxBig, - a_col: usize, - ); + fn vec_znx_big_automorphism(&self, k: i64, res: &mut R, res_col: usize, a: &A, a_col: usize) + where + R: VecZnxBigToMut, + A: VecZnxBigToRef; /// Applies the automorphism X^i -> X^ik on `a` and stores the result on `a`. 
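// Minimal sketch of the in-place automorphism with the new generic bound; k = 5 is an
// arbitrary odd exponent chosen only for illustration.
fn automorphism_sketch(module: &Module<FFT64>, big: &mut VecZnxBig<Vec<u8>, FFT64>) {
    // Maps X^i -> X^{5*i} in column 0 of `big`, without allocating a destination.
    module.vec_znx_big_automorphism_inplace(5, big, 0);
}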
- fn vec_znx_big_automorphism_inplace(&self, k: i64, a: &mut VecZnxBig, a_col: usize); + fn vec_znx_big_automorphism_inplace(&self, k: i64, a: &mut A, a_col: usize) + where + A: VecZnxBigToMut; } pub trait VecZnxBigScratch { @@ -157,29 +158,22 @@ impl VecZnxBigAlloc for Module { VecZnxBig::new_from_bytes(self, cols, size, bytes) } - // fn new_vec_znx_big_from_bytes_borrow(&self, cols: usize, size: usize, tmp_bytes: &mut [u8]) -> VecZnxBig { - // VecZnxBig::from_bytes_borrow(self, 1, cols, size, tmp_bytes) - // } - fn bytes_of_vec_znx_big(&self, cols: usize, size: usize) -> usize { bytes_of_vec_znx_big(self, cols, size) } } -impl VecZnxBigOps for Module -where - DataMut: AsMut<[u8]> + AsRef<[u8]>, - Data: AsRef<[u8]>, -{ - fn vec_znx_big_add( - &self, - res: &mut VecZnxBig, - res_col: usize, - a: &VecZnxBig, - a_col: usize, - b: &VecZnxBig, - b_col: usize, - ) { +impl VecZnxBigOps for Module { + fn vec_znx_big_add(&self, res: &mut R, res_col: usize, a: &A, a_col: usize, b: &B, b_col: usize) + where + R: VecZnxBigToMut, + A: VecZnxBigToRef, + B: VecZnxBigToRef, + { + let a: VecZnxBig<&[u8], FFT64> = a.to_ref(); + let b: VecZnxBig<&[u8], FFT64> = b.to_ref(); + let mut res: VecZnxBig<&mut [u8], FFT64> = res.to_mut(); + #[cfg(debug_assertions)] { assert_eq!(a.n(), self.n()); @@ -203,13 +197,14 @@ where } } - fn vec_znx_big_add_inplace( - &self, - res: &mut VecZnxBig, - res_col: usize, - a: &VecZnxBig, - a_col: usize, - ) { + fn vec_znx_big_add_inplace(&self, res: &mut R, res_col: usize, a: &A, a_col: usize) + where + R: VecZnxBigToMut, + A: VecZnxBigToRef, + { + let a: VecZnxBig<&[u8], FFT64> = a.to_ref(); + let mut res: VecZnxBig<&mut [u8], FFT64> = res.to_mut(); + #[cfg(debug_assertions)] { assert_eq!(a.n(), self.n()); @@ -231,15 +226,16 @@ where } } - fn vec_znx_big_sub( - &self, - res: &mut VecZnxBig, - res_col: usize, - a: &VecZnxBig, - a_col: usize, - b: &VecZnxBig, - b_col: usize, - ) { + fn vec_znx_big_sub(&self, res: &mut R, res_col: usize, a: &A, a_col: usize, b: &B, b_col: usize) + where + R: VecZnxBigToMut, + A: VecZnxBigToRef, + B: VecZnxBigToRef, + { + let a: VecZnxBig<&[u8], FFT64> = a.to_ref(); + let b: VecZnxBig<&[u8], FFT64> = b.to_ref(); + let mut res: VecZnxBig<&mut [u8], FFT64> = res.to_mut(); + #[cfg(debug_assertions)] { assert_eq!(a.n(), self.n()); @@ -263,13 +259,14 @@ where } } - fn vec_znx_big_sub_ab_inplace( - &self, - res: &mut VecZnxBig, - res_col: usize, - a: &VecZnxBig, - a_col: usize, - ) { + fn vec_znx_big_sub_ab_inplace(&self, res: &mut R, res_col: usize, a: &A, a_col: usize) + where + R: VecZnxBigToMut, + A: VecZnxBigToRef, + { + let a: VecZnxBig<&[u8], FFT64> = a.to_ref(); + let mut res: VecZnxBig<&mut [u8], FFT64> = res.to_mut(); + #[cfg(debug_assertions)] { assert_eq!(a.n(), self.n()); @@ -291,13 +288,14 @@ where } } - fn vec_znx_big_sub_ba_inplace( - &self, - res: &mut VecZnxBig, - res_col: usize, - a: &VecZnxBig, - a_col: usize, - ) { + fn vec_znx_big_sub_ba_inplace(&self, res: &mut R, res_col: usize, a: &A, a_col: usize) + where + R: VecZnxBigToMut, + A: VecZnxBigToRef, + { + let a: VecZnxBig<&[u8], FFT64> = a.to_ref(); + let mut res: VecZnxBig<&mut [u8], FFT64> = res.to_mut(); + #[cfg(debug_assertions)] { assert_eq!(a.n(), self.n()); @@ -319,15 +317,16 @@ where } } - fn vec_znx_big_sub_small_b( - &self, - res: &mut VecZnxBig, - res_col: usize, - a: &VecZnxBig, - a_col: usize, - b: &VecZnx, - b_col: usize, - ) { + fn vec_znx_big_sub_small_b(&self, res: &mut R, res_col: usize, a: &A, a_col: usize, b: &B, b_col: usize) + where + R: VecZnxBigToMut, + A: 
VecZnxBigToRef, + B: VecZnxToRef, + { + let a: VecZnxBig<&[u8], FFT64> = a.to_ref(); + let b: VecZnx<&[u8]> = b.to_ref(); + let mut res: VecZnxBig<&mut [u8], FFT64> = res.to_mut(); + #[cfg(debug_assertions)] { assert_eq!(a.n(), self.n()); @@ -351,13 +350,14 @@ where } } - fn vec_znx_big_sub_small_b_inplace( - &self, - res: &mut VecZnxBig, - res_col: usize, - a: &VecZnx, - a_col: usize, - ) { + fn vec_znx_big_sub_small_b_inplace(&self, res: &mut R, res_col: usize, a: &A, a_col: usize) + where + R: VecZnxBigToMut, + A: VecZnxToRef, + { + let a: VecZnx<&[u8]> = a.to_ref(); + let mut res: VecZnxBig<&mut [u8], FFT64> = res.to_mut(); + #[cfg(debug_assertions)] { assert_eq!(a.n(), self.n()); @@ -379,15 +379,16 @@ where } } - fn vec_znx_big_sub_small_a( - &self, - res: &mut VecZnxBig, - res_col: usize, - a: &VecZnx, - a_col: usize, - b: &VecZnxBig, - b_col: usize, - ) { + fn vec_znx_big_sub_small_a(&self, res: &mut R, res_col: usize, a: &A, a_col: usize, b: &B, b_col: usize) + where + R: VecZnxBigToMut, + A: VecZnxToRef, + B: VecZnxBigToRef, + { + let a: VecZnx<&[u8]> = a.to_ref(); + let b: VecZnxBig<&[u8], FFT64> = b.to_ref(); + let mut res: VecZnxBig<&mut [u8], FFT64> = res.to_mut(); + #[cfg(debug_assertions)] { assert_eq!(a.n(), self.n()); @@ -411,13 +412,14 @@ where } } - fn vec_znx_big_sub_small_a_inplace( - &self, - res: &mut VecZnxBig, - res_col: usize, - a: &VecZnx, - a_col: usize, - ) { + fn vec_znx_big_sub_small_a_inplace(&self, res: &mut R, res_col: usize, a: &A, a_col: usize) + where + R: VecZnxBigToMut, + A: VecZnxToRef, + { + let a: VecZnx<&[u8]> = a.to_ref(); + let mut res: VecZnxBig<&mut [u8], FFT64> = res.to_mut(); + #[cfg(debug_assertions)] { assert_eq!(a.n(), self.n()); @@ -439,15 +441,16 @@ where } } - fn vec_znx_big_add_small( - &self, - res: &mut VecZnxBig, - res_col: usize, - a: &VecZnxBig, - a_col: usize, - b: &VecZnx, - b_col: usize, - ) { + fn vec_znx_big_add_small(&self, res: &mut R, res_col: usize, a: &A, a_col: usize, b: &B, b_col: usize) + where + R: VecZnxBigToMut, + A: VecZnxBigToRef, + B: VecZnxToRef, + { + let a: VecZnxBig<&[u8], FFT64> = a.to_ref(); + let b: VecZnx<&[u8]> = b.to_ref(); + let mut res: VecZnxBig<&mut [u8], FFT64> = res.to_mut(); + #[cfg(debug_assertions)] { assert_eq!(a.n(), self.n()); @@ -471,7 +474,14 @@ where } } - fn vec_znx_big_add_small_inplace(&self, res: &mut VecZnxBig, res_col: usize, a: &VecZnx, a_col: usize) { + fn vec_znx_big_add_small_inplace(&self, res: &mut R, res_col: usize, a: &A, a_col: usize) + where + R: VecZnxBigToMut, + A: VecZnxToRef, + { + let a: VecZnx<&[u8]> = a.to_ref(); + let mut res: VecZnxBig<&mut [u8], FFT64> = res.to_mut(); + #[cfg(debug_assertions)] { assert_eq!(a.n(), self.n()); @@ -493,22 +503,28 @@ where } } - fn vec_znx_big_normalize( + fn vec_znx_big_normalize( &self, log_base2k: usize, - res: &mut VecZnx, + res: &mut R, res_col: usize, - a: &VecZnxBig, + a: &A, a_col: usize, scratch: &mut Scratch, - ) { + ) where + R: VecZnxToMut, + A: VecZnxBigToRef, + { + let a: VecZnxBig<&[u8], FFT64> = a.to_ref(); + let mut res: VecZnx<&mut [u8]> = res.to_mut(); + #[cfg(debug_assertions)] { assert_eq!(a.n(), self.n()); assert_eq!(res.n(), self.n()); //(Jay)Note: This is calling VezZnxOps::vec_znx_normalize_tmp_bytes and not VecZnxBigOps::vec_znx_big_normalize_tmp_bytes. 
// In the FFT backend the tmp sizes are same but will be different in the NTT backend - // assert!(tmp_bytes.len() >= >::vec_znx_normalize_tmp_bytes(&self)); + // assert!(tmp_bytes.len() >= >::vec_znx_normalize_tmp_bytes(&self)); // assert_alignement(tmp_bytes.as_ptr()); } @@ -530,14 +546,14 @@ where } } - fn vec_znx_big_automorphism( - &self, - k: i64, - res: &mut VecZnxBig, - res_col: usize, - a: &VecZnxBig, - a_col: usize, - ) { + fn vec_znx_big_automorphism(&self, k: i64, res: &mut R, res_col: usize, a: &A, a_col: usize) + where + R: VecZnxBigToMut, + A: VecZnxBigToRef, + { + let a: VecZnxBig<&[u8], FFT64> = a.to_ref(); + let mut res: VecZnxBig<&mut [u8], FFT64> = res.to_mut(); + #[cfg(debug_assertions)] { assert_eq!(a.n(), self.n()); @@ -557,7 +573,12 @@ where } } - fn vec_znx_big_automorphism_inplace(&self, k: i64, a: &mut VecZnxBig, a_col: usize) { + fn vec_znx_big_automorphism_inplace(&self, k: i64, a: &mut A, a_col: usize) + where + A: VecZnxBigToMut, + { + let mut a: VecZnxBig<&mut [u8], FFT64> = a.to_mut(); + #[cfg(debug_assertions)] { assert_eq!(a.n(), self.n()); diff --git a/base2k/src/vec_znx_dft.rs b/base2k/src/vec_znx_dft.rs index a4a3242..66e58cf 100644 --- a/base2k/src/vec_znx_dft.rs +++ b/base2k/src/vec_znx_dft.rs @@ -2,12 +2,9 @@ use std::marker::PhantomData; use crate::ffi::vec_znx_dft; use crate::znx_base::ZnxInfos; -use crate::{Backend, DataView, DataViewMut, FFT64, Module, ZnxView, alloc_aligned}; +use crate::{Backend, DataView, DataViewMut, FFT64, Module, ZnxSliceSize, ZnxView, alloc_aligned}; -// const VEC_ZNX_DFT_ROWS: usize = 1; - -// VecZnxDft is `Backend` dependent denoted with generic `B` -pub struct VecZnxDft { +pub struct VecZnxDft { data: D, n: usize, cols: usize, @@ -15,7 +12,7 @@ pub struct VecZnxDft { _phantom: PhantomData, } -impl ZnxInfos for VecZnxDft { +impl ZnxInfos for VecZnxDft { fn cols(&self) -> usize { self.cols } @@ -31,20 +28,22 @@ impl ZnxInfos for VecZnxDft { fn size(&self) -> usize { self.size } +} +impl ZnxSliceSize for VecZnxDft { fn sl(&self) -> usize { - self.cols() * self.n() + self.n() * self.cols() } } -impl DataView for VecZnxDft { +impl DataView for VecZnxDft { type D = D; fn data(&self) -> &Self::D { &self.data } } -impl DataViewMut for VecZnxDft { +impl DataViewMut for VecZnxDft { fn data_mut(&mut self) -> &mut Self::D { &mut self.data } @@ -85,7 +84,7 @@ impl>, B: Backend> VecZnxDft { pub type VecZnxDftOwned = VecZnxDft, B>; -impl VecZnxDft { +impl VecZnxDft { pub(crate) fn from_data(data: D, n: usize, cols: usize, size: usize) -> Self { Self { data, @@ -97,8 +96,16 @@ impl VecZnxDft { } } -impl VecZnxDft, B> { - pub fn to_mut(&mut self) -> VecZnxDft<&mut [u8], B> { +pub trait VecZnxDftToRef { + fn to_ref(&self) -> VecZnxDft<&[u8], B>; +} + +pub trait VecZnxDftToMut { + fn to_mut(&mut self) -> VecZnxDft<&mut [u8], B>; +} + +impl VecZnxDftToMut for VecZnxDft, B> { + fn to_mut(&mut self) -> VecZnxDft<&mut [u8], B> { VecZnxDft { data: self.data.as_mut_slice(), n: self.n, @@ -107,8 +114,10 @@ impl VecZnxDft, B> { _phantom: PhantomData, } } +} - pub fn to_ref(&self) -> VecZnxDft<&[u8], B> { +impl VecZnxDftToRef for VecZnxDft, B> { + fn to_ref(&self) -> VecZnxDft<&[u8], B> { VecZnxDft { data: self.data.as_slice(), n: self.n, @@ -119,10 +128,34 @@ impl VecZnxDft, B> { } } -impl VecZnxDft<&mut [u8], B> { - pub fn to_ref(&self) -> VecZnxDft<&[u8], B> { +impl VecZnxDftToMut for VecZnxDft<&mut [u8], B> { + fn to_mut(&mut self) -> VecZnxDft<&mut [u8], B> { VecZnxDft { - data: &self.data, + data: self.data, + n: self.n, + cols: self.cols, 
+ size: self.size, + _phantom: PhantomData, + } + } +} + +impl VecZnxDftToRef for VecZnxDft<&mut [u8], B> { + fn to_ref(&self) -> VecZnxDft<&[u8], B> { + VecZnxDft { + data: self.data, + n: self.n, + cols: self.cols, + size: self.size, + _phantom: PhantomData, + } + } +} + +impl VecZnxDftToRef for VecZnxDft<&[u8], B> { + fn to_ref(&self) -> VecZnxDft<&[u8], B> { + VecZnxDft { + data: self.data, n: self.n, cols: self.cols, size: self.size, diff --git a/base2k/src/vec_znx_dft_ops.rs b/base2k/src/vec_znx_dft_ops.rs index e894ef4..83b7c26 100644 --- a/base2k/src/vec_znx_dft_ops.rs +++ b/base2k/src/vec_znx_dft_ops.rs @@ -1,11 +1,11 @@ use crate::ffi::{vec_znx_big, vec_znx_dft}; use crate::vec_znx_dft::bytes_of_vec_znx_dft; use crate::znx_base::ZnxInfos; -use crate::{Backend, VecZnxDftOwned}; -use crate::{FFT64, Module, VecZnx, VecZnxBig, VecZnxDft, ZnxView, ZnxViewMut, ZnxZero, assert_alignement}; +use crate::{Backend, Scratch, VecZnxBigToMut, VecZnxDftOwned, VecZnxDftToMut, VecZnxDftToRef, VecZnxToRef, ZnxSliceSize}; +use crate::{FFT64, Module, ZnxView, ZnxViewMut, ZnxZero}; use std::cmp::min; -pub trait VecZnxDftAlloc { +pub trait VecZnxDftAlloc { /// Allocates a vector Z[X]/(X^N+1) that stores normalized in the DFT space. fn new_vec_znx_dft(&self, cols: usize, size: usize) -> VecZnxDftOwned; @@ -34,24 +34,26 @@ pub trait VecZnxDftAlloc { fn bytes_of_vec_znx_dft(&self, cols: usize, size: usize) -> usize; } -pub trait VecZnxDftOps { +pub trait VecZnxDftOps { /// Returns the minimum number of bytes necessary to allocate /// a new [VecZnxDft] through [VecZnxDft::from_bytes]. fn vec_znx_idft_tmp_bytes(&self) -> usize; /// b <- IDFT(a), uses a as scratch space. - fn vec_znx_idft_tmp_a(&self, res: &mut VecZnxBig, res_col: usize, a: &mut VecZnxDft, a_cols: usize); + fn vec_znx_idft_tmp_a(&self, res: &mut R, res_col: usize, a: &mut A, a_cols: usize) + where + R: VecZnxBigToMut, + A: VecZnxDftToMut; - fn vec_znx_idft( - &self, - res: &mut VecZnxBig, - res_col: usize, - a: &VecZnxDft, - a_col: usize, - tmp_bytes: &mut [u8], - ); + fn vec_znx_idft(&self, res: &mut R, res_col: usize, a: &A, a_col: usize, scratch: &mut Scratch) + where + R: VecZnxBigToMut, + A: VecZnxDftToRef; - fn vec_znx_dft(&self, res: &mut VecZnxDft, res_col: usize, a: &VecZnx, a_col: usize); + fn vec_znx_dft(&self, res: &mut R, res_col: usize, a: &A, a_col: usize) + where + R: VecZnxDftToMut, + A: VecZnxToRef; } impl VecZnxDftAlloc for Module { @@ -63,41 +65,34 @@ impl VecZnxDftAlloc for Module { VecZnxDftOwned::new_from_bytes(self, cols, size, bytes) } - // fn new_vec_znx_dft_from_bytes_borrow(&self, cols: usize, size: usize, bytes: &mut [u8]) -> VecZnxDft { - // VecZnxDft::from_bytes_borrow(self, 1, cols, size, bytes) - // } - fn bytes_of_vec_znx_dft(&self, cols: usize, size: usize) -> usize { bytes_of_vec_znx_dft(self, cols, size) } } -impl VecZnxDftOps for Module -where - DataMut: AsMut<[u8]> + AsRef<[u8]>, - Data: AsRef<[u8]>, -{ - fn vec_znx_idft_tmp_a( - &self, - res: &mut VecZnxBig, - res_col: usize, - a: &mut VecZnxDft, - a_col: usize, - ) { - let min_size: usize = min(res.size(), a.size()); +impl VecZnxDftOps for Module { + fn vec_znx_idft_tmp_a(&self, res: &mut R, res_col: usize, a: &mut A, a_col: usize) + where + R: VecZnxBigToMut, + A: VecZnxDftToMut, + { + let mut res_mut = res.to_mut(); + let mut a_mut = a.to_mut(); + + let min_size: usize = min(res_mut.size(), a_mut.size()); unsafe { (0..min_size).for_each(|j| { vec_znx_dft::vec_znx_idft_tmp_a( self.ptr, - res.at_mut_ptr(res_col, j) as *mut vec_znx_big::vec_znx_big_t, 
+ res_mut.at_mut_ptr(res_col, j) as *mut vec_znx_big::vec_znx_big_t, 1 as u64, - a.at_mut_ptr(a_col, j) as *mut vec_znx_dft::vec_znx_dft_t, + a_mut.at_mut_ptr(a_col, j) as *mut vec_znx_dft::vec_znx_dft_t, 1 as u64, ) }); - (min_size..res.size()).for_each(|j| { - res.zero_at(res_col, j); + (min_size..res_mut.size()).for_each(|j| { + res_mut.zero_at(res_col, j); }) } } @@ -110,61 +105,59 @@ where /// /// # Panics /// If b.cols < a_cols - fn vec_znx_dft(&self, res: &mut VecZnxDft, res_col: usize, a: &VecZnx, a_col: usize) { - let min_size: usize = min(res.size(), a.size()); + fn vec_znx_dft(&self, res: &mut R, res_col: usize, a: &A, a_col: usize) + where + R: VecZnxDftToMut, + A: VecZnxToRef, + { + let mut res_mut = res.to_mut(); + let a_ref = a.to_ref(); + + let min_size: usize = min(res_mut.size(), a_ref.size()); unsafe { (0..min_size).for_each(|j| { vec_znx_dft::vec_znx_dft( self.ptr, - res.at_mut_ptr(res_col, j) as *mut vec_znx_dft::vec_znx_dft_t, + res_mut.at_mut_ptr(res_col, j) as *mut vec_znx_dft::vec_znx_dft_t, 1 as u64, - a.at_ptr(a_col, j), + a_ref.at_ptr(a_col, j), 1 as u64, - a.sl() as u64, + a_ref.sl() as u64, ) }); - (min_size..res.size()).for_each(|j| { - res.zero_at(res_col, j); + (min_size..res_mut.size()).for_each(|j| { + res_mut.zero_at(res_col, j); }); } } // b <- IDFT(a), scratch space size obtained with [vec_znx_idft_tmp_bytes]. - fn vec_znx_idft( - &self, - res: &mut VecZnxBig, - res_col: usize, - a: &VecZnxDft, - a_col: usize, - tmp_bytes: &mut [u8], - ) { - #[cfg(debug_assertions)] - { - assert!( - tmp_bytes.len() >= >::vec_znx_idft_tmp_bytes(self), - "invalid tmp_bytes: tmp_bytes.len()={} < self.vec_znx_idft_tmp_bytes()={}", - tmp_bytes.len(), - >::vec_znx_idft_tmp_bytes(self) - ); - assert_alignement(tmp_bytes.as_ptr()) - } + fn vec_znx_idft(&self, res: &mut R, res_col: usize, a: &A, a_col: usize, scratch: &mut Scratch) + where + R: VecZnxBigToMut, + A: VecZnxDftToRef, + { + let mut res_mut = res.to_mut(); + let a_ref = a.to_ref(); - let min_size: usize = min(res.size(), a.size()); + let (tmp_bytes, _) = scratch.tmp_scalar_slice(self.vec_znx_idft_tmp_bytes()); + + let min_size: usize = min(res_mut.size(), a_ref.size()); unsafe { (0..min_size).for_each(|j| { vec_znx_dft::vec_znx_idft( self.ptr, - res.at_mut_ptr(res_col, j) as *mut vec_znx_big::vec_znx_big_t, + res_mut.at_mut_ptr(res_col, j) as *mut vec_znx_big::vec_znx_big_t, 1 as u64, - a.at_ptr(a_col, j) as *const vec_znx_dft::vec_znx_dft_t, + a_ref.at_ptr(a_col, j) as *const vec_znx_dft::vec_znx_dft_t, 1 as u64, tmp_bytes.as_mut_ptr(), ) }); - (min_size..res.size()).for_each(|j| { - res.zero_at(res_col, j); + (min_size..res_mut.size()).for_each(|j| { + res_mut.zero_at(res_col, j); }); } } diff --git a/base2k/src/vec_znx_ops.rs b/base2k/src/vec_znx_ops.rs index a8edb12..cdabe24 100644 --- a/base2k/src/vec_znx_ops.rs +++ b/base2k/src/vec_znx_ops.rs @@ -1,6 +1,9 @@ use crate::ffi::vec_znx; -use crate::znx_base::{ZnxInfos, switch_degree}; -use crate::{Backend, Module, VecZnx, VecZnxOwned, ZnxView, ZnxViewMut, assert_alignement}; +use crate::{ + Backend, Module, Scratch, VecZnx, VecZnxOwned, VecZnxToMut, VecZnxToRef, ZnxInfos, ZnxSliceSize, ZnxView, ZnxViewMut, ZnxZero, +}; +use itertools::izip; +use std::cmp::min; pub trait VecZnxAlloc { /// Allocates a new [VecZnx]. @@ -29,73 +32,86 @@ pub trait VecZnxAlloc { fn bytes_of_vec_znx(&self, cols: usize, size: usize) -> usize; } -pub trait VecZnxOps { +pub trait VecZnxOps { /// Normalizes the selected column of `a` and stores the result into the selected column of `res`. 
- fn vec_znx_normalize( - &self, - log_base2k: usize, - res: &mut VecZnx, - res_col: usize, - a: &VecZnx, - a_col: usize, - tmp_bytes: &mut [u8], - ); + fn vec_znx_normalize(&self, log_base2k: usize, res: &mut R, res_col: usize, a: &A, a_col: usize, scratch: &mut Scratch) + where + R: VecZnxToMut, + A: VecZnxToRef; /// Normalizes the selected column of `a`. - fn vec_znx_normalize_inplace(&self, log_base2k: usize, a: &mut VecZnx, a_col: usize, tmp_bytes: &mut [u8]); + fn vec_znx_normalize_inplace(&self, log_base2k: usize, a: &mut A, a_col: usize, scratch: &mut Scratch) + where + A: VecZnxToMut; /// Adds the selected column of `a` to the selected column of `b` and writes the result on the selected column of `res`. - fn vec_znx_add( - &self, - res: &mut VecZnx, - res_col: usize, - a: &VecZnx, - a_col: usize, - b: &VecZnx, - b_col: usize, - ); + fn vec_znx_add(&self, res: &mut R, res_col: usize, a: &A, a_col: usize, b: &B, b_col: usize) + where + R: VecZnxToMut, + A: VecZnxToRef, + B: VecZnxToRef; /// Adds the selected column of `a` to the selected column of `b` and writes the result on the selected column of `res`. - fn vec_znx_add_inplace(&self, res: &mut VecZnx, res_col: usize, a: &VecZnx, a_col: usize); + fn vec_znx_add_inplace(&self, res: &mut R, res_col: usize, a: &A, a_col: usize) + where + R: VecZnxToMut, + A: VecZnxToRef; /// Subtracts the selected column of `b` from the selected column of `a` and writes the result on the selected column of `res`. - fn vec_znx_sub( - &self, - res: &mut VecZnx, - res_col: usize, - a: &VecZnx, - a_col: usize, - b: &VecZnx, - b_col: usize, - ); + fn vec_znx_sub(&self, res: &mut R, res_col: usize, a: &A, a_col: usize, b: &B, b_col: usize) + where + R: VecZnxToMut, + A: VecZnxToRef, + B: VecZnxToRef; /// Subtracts the selected column of `a` from the selected column of `res` inplace. /// /// res[res_col] -= a[a_col] - fn vec_znx_sub_ab_inplace(&self, res: &mut VecZnx, res_col: usize, a: &VecZnx, a_col: usize); + fn vec_znx_sub_ab_inplace(&self, res: &mut R, res_col: usize, a: &A, a_col: usize) + where + R: VecZnxToMut, + A: VecZnxToRef; /// Subtracts the selected column of `res` from the selected column of `a` and inplace mutates `res` /// /// res[res_col] = a[a_col] - res[res_col] - fn vec_znx_sub_ba_inplace(&self, res: &mut VecZnx, res_col: usize, a: &VecZnx, a_col: usize); + fn vec_znx_sub_ba_inplace(&self, res: &mut R, res_col: usize, a: &A, a_col: usize) + where + R: VecZnxToMut, + A: VecZnxToRef; // Negates the selected column of `a` and stores the result in `res_col` of `res`. - fn vec_znx_negate(&self, res: &mut VecZnx, res_col: usize, a: &VecZnx, a_col: usize); + fn vec_znx_negate(&self, res: &mut R, res_col: usize, a: &A, a_col: usize) + where + R: VecZnxToMut, + A: VecZnxToRef; /// Negates the selected column of `a`. - fn vec_znx_negate_inplace(&self, a: &mut VecZnx, a_col: usize); + fn vec_znx_negate_inplace(&self, a: &mut A, a_col: usize) + where + A: VecZnxToMut; /// Multiplies the selected column of `a` by X^k and stores the result in `res_col` of `res`. - fn vec_znx_rotate(&self, k: i64, res: &mut VecZnx, res_col: usize, a: &VecZnx, a_col: usize); + fn vec_znx_rotate(&self, k: i64, res: &mut R, res_col: usize, a: &A, a_col: usize) + where + R: VecZnxToMut, + A: VecZnxToRef; /// Multiplies the selected column of `a` by X^k. 
- fn vec_znx_rotate_inplace(&self, k: i64, a: &mut VecZnx, a_col: usize); + fn vec_znx_rotate_inplace(&self, k: i64, a: &mut A, a_col: usize) + where + A: VecZnxToMut; /// Applies the automorphism X^i -> X^ik on the selected column of `a` and stores the result in `res_col` column of `res`. - fn vec_znx_automorphism(&self, k: i64, res: &mut VecZnx, res_col: usize, a: &VecZnx, a_col: usize); + fn vec_znx_automorphism(&self, k: i64, res: &mut R, res_col: usize, a: &A, a_col: usize) + where + R: VecZnxToMut, + A: VecZnxToRef; /// Applies the automorphism X^i -> X^ik on the selected column of `a`. - fn vec_znx_automorphism_inplace(&self, k: i64, a: &mut VecZnx, a_col: usize); + fn vec_znx_automorphism_inplace(&self, k: i64, a: &mut A, a_col: usize) + where + A: VecZnxToMut; /// Splits the selected columns of `b` into subrings and copies them them into the selected column of `res`. /// @@ -103,14 +119,10 @@ pub trait VecZnxOps { /// /// This method requires that all [VecZnx] of b have the same ring degree /// and that b.n() * b.len() <= a.n() - fn vec_znx_split( - &self, - res: &mut Vec>, - res_col: usize, - a: &VecZnx, - a_col: usize, - buf: &mut VecZnx, - ); + fn vec_znx_split(&self, res: &mut Vec, res_col: usize, a: &A, a_col: usize, scratch: &mut Scratch) + where + R: VecZnxToMut, + A: VecZnxToRef; /// Merges the subrings of the selected column of `a` into the selected column of `res`. /// @@ -118,7 +130,15 @@ pub trait VecZnxOps { /// /// This method requires that all [VecZnx] of a have the same ring degree /// and that a.n() * a.len() <= b.n() - fn vec_znx_merge(&self, res: &mut VecZnx, res_col: usize, a: &Vec>, a_col: usize); + fn vec_znx_merge(&self, res: &mut R, res_col: usize, a: Vec, a_col: usize) + where + R: VecZnxToMut, + A: VecZnxToRef; + + fn switch_degree(&self, r: &mut R, col_b: usize, a: &A, col_a: usize) + where + R: VecZnxToMut, + A: VecZnxToRef; } pub trait VecZnxScratch { @@ -140,27 +160,23 @@ impl VecZnxAlloc for Module { } } -impl VecZnxOps for Module -where - Data: AsRef<[u8]>, - DataMut: AsRef<[u8]> + AsMut<[u8]>, -{ - fn vec_znx_normalize( - &self, - log_base2k: usize, - res: &mut VecZnx, - res_col: usize, - a: &VecZnx, - a_col: usize, - tmp_bytes: &mut [u8], - ) { +impl VecZnxOps for Module { + fn vec_znx_normalize(&self, log_base2k: usize, res: &mut R, res_col: usize, a: &A, a_col: usize, scratch: &mut Scratch) + where + R: VecZnxToMut, + A: VecZnxToRef, + { + let a: VecZnx<&[u8]> = a.to_ref(); + let mut res: VecZnx<&mut [u8]> = res.to_mut(); + #[cfg(debug_assertions)] { assert_eq!(a.n(), self.n()); assert_eq!(res.n(), self.n()); - assert!(tmp_bytes.len() >= ::vec_znx_normalize_tmp_bytes(&self)); - assert_alignement(tmp_bytes.as_ptr()); } + + let (tmp_bytes, _) = scratch.tmp_scalar_slice(self.vec_znx_normalize_tmp_bytes()); + unsafe { vec_znx::vec_znx_normalize_base2k( self.ptr, @@ -176,22 +192,44 @@ where } } - fn vec_znx_normalize_inplace(&self, log_base2k: usize, a: &mut VecZnx, a_col: usize, tmp_bytes: &mut [u8]) { + fn vec_znx_normalize_inplace(&self, log_base2k: usize, a: &mut A, a_col: usize, scratch: &mut Scratch) + where + A: VecZnxToMut, + { + let mut a: VecZnx<&mut [u8]> = a.to_mut(); + + #[cfg(debug_assertions)] + { + assert_eq!(a.n(), self.n()); + } + + let (tmp_bytes, _) = scratch.tmp_scalar_slice(self.vec_znx_normalize_tmp_bytes()); + unsafe { - let a_ptr: *const VecZnx<_> = a; - Self::vec_znx_normalize(self, log_base2k, a, a_col, &*a_ptr, a_col, tmp_bytes); + vec_znx::vec_znx_normalize_base2k( + self.ptr, + log_base2k as u64, + a.at_mut_ptr(a_col, 
0), + a.size() as u64, + a.sl() as u64, + a.at_ptr(a_col, 0), + a.size() as u64, + a.sl() as u64, + tmp_bytes.as_mut_ptr(), + ); } } - fn vec_znx_add( - &self, - res: &mut VecZnx, - res_col: usize, - a: &VecZnx, - a_col: usize, - b: &VecZnx, - b_col: usize, - ) { + fn vec_znx_add(&self, res: &mut R, res_col: usize, a: &A, a_col: usize, b: &B, b_col: usize) + where + R: VecZnxToMut, + A: VecZnxToRef, + B: VecZnxToRef, + { + let a: VecZnx<&[u8]> = a.to_ref(); + let b: VecZnx<&[u8]> = b.to_ref(); + let mut res: VecZnx<&mut [u8]> = res.to_mut(); + #[cfg(debug_assertions)] { assert_eq!(a.n(), self.n()); @@ -215,7 +253,14 @@ where } } - fn vec_znx_add_inplace(&self, res: &mut VecZnx, res_col: usize, a: &VecZnx, a_col: usize) { + fn vec_znx_add_inplace(&self, res: &mut R, res_col: usize, a: &A, a_col: usize) + where + R: VecZnxToMut, + A: VecZnxToRef, + { + let a: VecZnx<&[u8]> = a.to_ref(); + let mut res: VecZnx<&mut [u8]> = res.to_mut(); + #[cfg(debug_assertions)] { assert_eq!(a.n(), self.n()); @@ -237,15 +282,16 @@ where } } - fn vec_znx_sub( - &self, - res: &mut VecZnx, - res_col: usize, - a: &VecZnx, - a_col: usize, - b: &VecZnx, - b_col: usize, - ) { + fn vec_znx_sub(&self, res: &mut R, res_col: usize, a: &A, a_col: usize, b: &B, b_col: usize) + where + R: VecZnxToMut, + A: VecZnxToRef, + B: VecZnxToRef, + { + let a: VecZnx<&[u8]> = a.to_ref(); + let b: VecZnx<&[u8]> = b.to_ref(); + let mut res: VecZnx<&mut [u8]> = res.to_mut(); + #[cfg(debug_assertions)] { assert_eq!(a.n(), self.n()); @@ -269,7 +315,13 @@ where } } - fn vec_znx_sub_ab_inplace(&self, res: &mut VecZnx, res_col: usize, a: &VecZnx, a_col: usize) { + fn vec_znx_sub_ab_inplace(&self, res: &mut R, res_col: usize, a: &A, a_col: usize) + where + R: VecZnxToMut, + A: VecZnxToRef, + { + let a: VecZnx<&[u8]> = a.to_ref(); + let mut res: VecZnx<&mut [u8]> = res.to_mut(); #[cfg(debug_assertions)] { assert_eq!(a.n(), self.n()); @@ -291,7 +343,13 @@ where } } - fn vec_znx_sub_ba_inplace(&self, res: &mut VecZnx, res_col: usize, a: &VecZnx, a_col: usize) { + fn vec_znx_sub_ba_inplace(&self, res: &mut R, res_col: usize, a: &A, a_col: usize) + where + R: VecZnxToMut, + A: VecZnxToRef, + { + let a: VecZnx<&[u8]> = a.to_ref(); + let mut res: VecZnx<&mut [u8]> = res.to_mut(); #[cfg(debug_assertions)] { assert_eq!(a.n(), self.n()); @@ -313,7 +371,13 @@ where } } - fn vec_znx_negate(&self, res: &mut VecZnx, res_col: usize, a: &VecZnx, a_col: usize) { + fn vec_znx_negate(&self, res: &mut R, res_col: usize, a: &A, a_col: usize) + where + R: VecZnxToMut, + A: VecZnxToRef, + { + let a: VecZnx<&[u8]> = a.to_ref(); + let mut res: VecZnx<&mut [u8]> = res.to_mut(); #[cfg(debug_assertions)] { assert_eq!(a.n(), self.n()); @@ -332,14 +396,35 @@ where } } - fn vec_znx_negate_inplace(&self, a: &mut VecZnx, a_col: usize) { + fn vec_znx_negate_inplace(&self, a: &mut A, a_col: usize) + where + A: VecZnxToMut, + { + let mut a: VecZnx<&mut [u8]> = a.to_mut(); + #[cfg(debug_assertions)] + { + assert_eq!(a.n(), self.n()); + } unsafe { - let a_ref: *const VecZnx<_> = a; - Self::vec_znx_negate(self, a, a_col, a_ref.as_ref().unwrap(), a_col); + vec_znx::vec_znx_negate( + self.ptr, + a.at_mut_ptr(a_col, 0), + a.size() as u64, + a.sl() as u64, + a.at_ptr(a_col, 0), + a.size() as u64, + a.sl() as u64, + ) } } - fn vec_znx_rotate(&self, k: i64, res: &mut VecZnx, res_col: usize, a: &VecZnx, a_col: usize) { + fn vec_znx_rotate(&self, k: i64, res: &mut R, res_col: usize, a: &A, a_col: usize) + where + R: VecZnxToMut, + A: VecZnxToRef, + { + let a: VecZnx<&[u8]> = a.to_ref(); + 
let mut res: VecZnx<&mut [u8]> = res.to_mut(); #[cfg(debug_assertions)] { assert_eq!(a.n(), self.n()); @@ -359,7 +444,11 @@ where } } - fn vec_znx_rotate_inplace(&self, k: i64, a: &mut VecZnx, a_col: usize) { + fn vec_znx_rotate_inplace(&self, k: i64, a: &mut A, a_col: usize) + where + A: VecZnxToMut, + { + let mut a: VecZnx<&mut [u8]> = a.to_mut(); #[cfg(debug_assertions)] { assert_eq!(a.n(), self.n()); @@ -378,7 +467,13 @@ where } } - fn vec_znx_automorphism(&self, k: i64, res: &mut VecZnx, res_col: usize, a: &VecZnx, a_col: usize) { + fn vec_znx_automorphism(&self, k: i64, res: &mut R, res_col: usize, a: &A, a_col: usize) + where + R: VecZnxToMut, + A: VecZnxToRef, + { + let a: VecZnx<&[u8]> = a.to_ref(); + let mut res: VecZnx<&mut [u8]> = res.to_mut(); #[cfg(debug_assertions)] { assert_eq!(a.n(), self.n()); @@ -398,7 +493,11 @@ where } } - fn vec_znx_automorphism_inplace(&self, k: i64, a: &mut VecZnx, a_col: usize) { + fn vec_znx_automorphism_inplace(&self, k: i64, a: &mut A, a_col: usize) + where + A: VecZnxToMut, + { + let mut a: VecZnx<&mut [u8]> = a.to_mut(); #[cfg(debug_assertions)] { assert_eq!(a.n(), self.n()); @@ -417,23 +516,24 @@ where } } - fn vec_znx_split( - &self, - res: &mut Vec>, - res_col: usize, - a: &VecZnx, - a_col: usize, - buf: &mut VecZnx, - ) { - let (n_in, n_out) = (a.n(), res[0].n()); + fn vec_znx_split(&self, res: &mut Vec, res_col: usize, a: &A, a_col: usize, scratch: &mut Scratch) + where + R: VecZnxToMut, + A: VecZnxToRef, + { + let a: VecZnx<&[u8]> = a.to_ref(); + + let (n_in, n_out) = (a.n(), res[0].to_mut().n()); + + let (mut buf, _) = scratch.tmp_vec_znx(self, 1, a.size()); debug_assert!( n_out < n_in, "invalid a: output ring degree should be smaller" ); - res[1..].iter().for_each(|bi| { + res[1..].iter_mut().for_each(|bi| { debug_assert_eq!( - bi.n(), + bi.to_mut().n(), n_out, "invalid input a: all VecZnx must have the same degree" ) @@ -441,17 +541,23 @@ where res.iter_mut().enumerate().for_each(|(i, bi)| { if i == 0 { - switch_degree(bi, res_col, a, a_col); - self.vec_znx_rotate(-1, buf, 0, a, a_col); + self.switch_degree(bi, res_col, &a, a_col); + self.vec_znx_rotate(-1, &mut buf, 0, &a, a_col); } else { - switch_degree(bi, res_col, buf, a_col); - >::vec_znx_rotate_inplace(self, -1, buf, a_col); + self.switch_degree(bi, res_col, &mut buf, a_col); + self.vec_znx_rotate_inplace(-1, &mut buf, a_col); } }) } - fn vec_znx_merge(&self, res: &mut VecZnx, res_col: usize, a: &Vec>, a_col: usize) { - let (n_in, n_out) = (res.n(), a[0].n()); + fn vec_znx_merge(&self, res: &mut R, res_col: usize, a: Vec, a_col: usize) + where + R: VecZnxToMut, + A: VecZnxToRef, + { + let mut res: VecZnx<&mut [u8]> = res.to_mut(); + + let (n_in, n_out) = (res.n(), a[0].to_ref().n()); debug_assert!( n_out < n_in, @@ -459,18 +565,47 @@ where ); a[1..].iter().for_each(|ai| { debug_assert_eq!( - ai.n(), + ai.to_ref().n(), n_out, "invalid input a: all VecZnx must have the same degree" ) }); a.iter().enumerate().for_each(|(_, ai)| { - switch_degree(res, res_col, ai, a_col); - >::vec_znx_rotate_inplace(self, -1, res, res_col); + self.switch_degree(&mut res, res_col, ai, a_col); + self.vec_znx_rotate_inplace(-1, &mut res, res_col); }); - >::vec_znx_rotate_inplace(self, a.len() as i64, res, res_col); + self.vec_znx_rotate_inplace(a.len() as i64, &mut res, res_col); + } + + fn switch_degree(&self, res: &mut R, res_col: usize, a: &A, a_col: usize) + where + R: VecZnxToMut, + A: VecZnxToRef, + { + let a: VecZnx<&[u8]> = a.to_ref(); + let mut res: VecZnx<&mut [u8]> = res.to_mut(); + + let 
(n_in, n_out) = (a.n(), res.n()); + let (gap_in, gap_out): (usize, usize); + + if n_in > n_out { + (gap_in, gap_out) = (n_in / n_out, 1) + } else { + (gap_in, gap_out) = (1, n_out / n_in); + res.zero(); + } + + let size: usize = min(a.size(), res.size()); + + (0..size).for_each(|i| { + izip!( + a.at(a_col, i).iter().step_by(gap_in), + res.at_mut(res_col, i).iter_mut().step_by(gap_out) + ) + .for_each(|(x_in, x_out)| *x_out = *x_in); + }); } } diff --git a/base2k/src/znx_base.rs b/base2k/src/znx_base.rs index 9eea5bb..db6a50c 100644 --- a/base2k/src/znx_base.rs +++ b/base2k/src/znx_base.rs @@ -1,6 +1,5 @@ use itertools::izip; use rand_distr::num_traits::Zero; -use std::cmp::min; pub trait ZnxInfos { /// Returns the ring degree of the polynomials. @@ -24,7 +23,9 @@ pub trait ZnxInfos { fn poly_count(&self) -> usize { self.rows() * self.cols() * self.size() } +} +pub trait ZnxSliceSize { /// Returns the slice size, which is the offset between /// two size of the same column. fn sl(&self) -> usize; @@ -129,33 +130,6 @@ where impl ZnxZero for T where T: ZnxViewMut {} // impl ZnxRsh for T where T: ZnxZero {} -pub fn switch_degree + ZnxZero, D: ZnxView>( - b: &mut DMut, - col_b: usize, - a: &D, - col_a: usize, -) { - let (n_in, n_out) = (a.n(), b.n()); - let (gap_in, gap_out): (usize, usize); - - if n_in > n_out { - (gap_in, gap_out) = (n_in / n_out, 1) - } else { - (gap_in, gap_out) = (1, n_out / n_in); - b.zero(); - } - - let size: usize = min(a.size(), b.size()); - - (0..size).for_each(|i| { - izip!( - a.at(col_a, i).iter().step_by(gap_in), - b.at_mut(col_b, i).iter_mut().step_by(gap_out) - ) - .for_each(|(x_in, x_out)| *x_out = *x_in); - }); -} - use std::ops::{Add, AddAssign, Div, Mul, Neg, Shl, Shr, Sub}; use crate::Scratch;
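
For reference, below is a minimal usage sketch (illustrative only, not part of the patch) of how the reworked entry points are intended to be called: the concrete `VecZnx*<Data>` parameters are replaced by generic `*ToRef`/`*ToMut` bounds, and temporary buffers are taken from a `Scratch` handle instead of raw `&mut [u8]` slices, so owned and borrowed containers go through the same methods. It assumes a `Module<FFT64>` and a borrowed `Scratch` are already available (e.g. obtained as in base2k/examples/rlwe_encrypt.rs via `ScratchOwned::new(..)`), and the sizes used (`cols = 1`, `size = 4`) are arbitrary example values.

// Usage sketch under the assumptions stated above; sizes are illustrative.
use base2k::{
    FFT64, Module, Scratch, VecZnxAlloc, VecZnxBigAlloc, VecZnxBigOps, VecZnxDftAlloc,
    VecZnxDftOps, VecZnxOps,
};

fn roundtrip_example(module: &Module<FFT64>, log_base2k: usize, scratch: &mut Scratch) {
    let (cols, size) = (1, 4);

    // Owned containers; they satisfy the new `*ToRef`/`*ToMut` bounds directly.
    let a = module.new_vec_znx(cols, size);             // VecZnx<Vec<u8>>
    let mut a_dft = module.new_vec_znx_dft(cols, size); // VecZnxDft<Vec<u8>, FFT64>
    let mut a_big = module.new_vec_znx_big(cols, size); // VecZnxBig<Vec<u8>, FFT64>
    let mut res = module.new_vec_znx(cols, size);

    // DFT -> IDFT round trip on column 0, reusing `a_dft` as scratch for the IDFT.
    module.vec_znx_dft(&mut a_dft, 0, &a, 0);
    module.vec_znx_idft_tmp_a(&mut a_big, 0, &mut a_dft, 0);

    // Accumulate a small (normalized) vector into the big one, then normalize back,
    // with the temporary buffer drawn from `scratch` rather than a raw byte slice.
    module.vec_znx_big_add_small_inplace(&mut a_big, 0, &a, 0);
    module.vec_znx_big_normalize(log_base2k, &mut res, 0, &a_big, 0, scratch);
}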