diff --git a/base2k/.vscode/settings.json b/base2k/.vscode/settings.json
index eecbcdc..c38916e 100644
--- a/base2k/.vscode/settings.json
+++ b/base2k/.vscode/settings.json
@@ -4,5 +4,8 @@
         "plaintext": false,
         "markdown": false,
         "scminput": false
+    },
+    "files.associations": {
+        "random": "c"
     }
 }
\ No newline at end of file
diff --git a/base2k/examples/rlwe_encrypt.rs b/base2k/examples/rlwe_encrypt.rs
index 07fe1c6..2f08633 100644
--- a/base2k/examples/rlwe_encrypt.rs
+++ b/base2k/examples/rlwe_encrypt.rs
@@ -13,7 +13,8 @@ fn main() {
     let log_scale: usize = msg_size * log_base2k - 5;
 
     let module: Module<FFT64> = Module::<FFT64>::new(n);
-    let mut carry: Vec<u8> = alloc_aligned(module.vec_znx_big_normalize_tmp_bytes());
+    let mut tmp_bytes_norm: Vec<u8> = alloc_aligned(module.vec_znx_big_normalize_tmp_bytes());
+    let mut tmp_bytes_dft = alloc_aligned(module.bytes_of_vec_znx_dft(1, ct_size));
 
     let seed: [u8; 32] = [0; 32];
     let mut source: Source = Source::new(seed);
@@ -38,9 +39,10 @@ fn main() {
     module.fill_uniform(log_base2k, &mut ct, 1, ct_size, &mut source);
 
     // Scratch space for DFT values
-    let mut buf_dft: VecZnxDft<FFT64> = module.new_vec_znx_dft(
+    let mut buf_dft: VecZnxDft<FFT64> = module.new_vec_znx_dft_from_bytes_borrow(
         1,         // Number of columns
         ct.size(), // Number of polynomials per column
+        &mut tmp_bytes_dft,
     );
 
     // Applies DFT(ct[1]) * DFT(s)
@@ -68,7 +70,7 @@ fn main() {
     want.iter_mut()
         .for_each(|x| *x = source.next_u64n(16, 15) as i64);
     m.encode_vec_i64(0, log_base2k, log_scale, &want, 4);
-    m.normalize(log_base2k, &mut carry);
+    m.normalize(log_base2k, 0, &mut tmp_bytes_norm);
 
     // m - BIG(ct[1] * s)
     module.vec_znx_big_sub_small_a_inplace(
@@ -81,9 +83,12 @@ fn main() {
     // Normalizes back to VecZnx
     // ct[0] <- m - BIG(c1 * s)
     module.vec_znx_big_normalize(
-        log_base2k, &mut ct, 0, // Selects the first column of ct (ct[0])
-        &buf_big, 0, // Selects the first column of buf_big
-        &mut carry,
+        log_base2k,
+        &mut ct,
+        0, // Selects the first column of ct (ct[0])
+        &buf_big,
+        0, // Selects the first column of buf_big
+        &mut tmp_bytes_norm,
     );
 
     // Add noise to ct[0]
@@ -120,7 +125,7 @@ fn main() {
     // m + e <- BIG(ct[1] * s + ct[0])
     let mut res: VecZnx = module.new_vec_znx(1, ct_size);
-    module.vec_znx_big_normalize(log_base2k, &mut res, 0, &buf_big, 0, &mut carry);
+    module.vec_znx_big_normalize(log_base2k, &mut res, 0, &buf_big, 0, &mut tmp_bytes_norm);
 
     // have = m * 2^{log_scale} + e
     let mut have: Vec<i64> = vec![i64::default(); n];
diff --git a/base2k/examples/vector_matrix_product.rs b/base2k/examples/vector_matrix_product.rs
deleted file mode 100644
index e565be1..0000000
--- a/base2k/examples/vector_matrix_product.rs
+++ /dev/null
@@ -1,59 +0,0 @@
-use base2k::{
-    Encoding, FFT64, MatZnxDft, MatZnxDftOps, Module, VecZnx, VecZnxBig, VecZnxBigOps, VecZnxDft, VecZnxDftOps, VecZnxOps,
-    ZnxInfos, ZnxLayout, alloc_aligned,
-};
-
-fn main() {
-    let log_n: i32 = 5;
-    let n: usize = 1 << log_n;
-
-    let module: Module<FFT64> = Module::<FFT64>::new(n);
-    let log_base2k: usize = 15;
-    let limbs_vec: usize = 5;
-    let log_k: usize = log_base2k * limbs_vec - 5;
-
-    let rows_mat: usize = limbs_vec;
-    let limbs_mat: usize = limbs_vec + 1;
-
-    // Maximum size of the byte scratch needed
-    let tmp_bytes: usize = module.vmp_prepare_tmp_bytes(rows_mat, 1, limbs_mat)
-        | module.vmp_apply_dft_tmp_bytes(limbs_vec, limbs_vec, rows_mat, limbs_mat);
-
-    let mut buf: Vec<u8> = alloc_aligned(tmp_bytes);
-
-    let mut a_values: Vec<i64> = vec![i64::default(); n];
-    a_values[1] = (1 << log_base2k) + 1;
-
-    let mut a: VecZnx = module.new_vec_znx(1, limbs_vec);
-    a.encode_vec_i64(0, log_base2k, log_k, &a_values, 32);
-    a.normalize(log_base2k, &mut buf);
-
-    a.print(n);
-    println!();
-
-    let mut mat_znx_dft: MatZnxDft<FFT64> = module.new_mat_znx_dft(rows_mat, 1, limbs_mat);
-
-    (0..a.size()).for_each(|row_i| {
-        let mut tmp: VecZnx = module.new_vec_znx(1, limbs_mat);
-        tmp.at_limb_mut(row_i)[1] = 1 as i64;
-        module.vmp_prepare_row(&mut mat_znx_dft, tmp.raw(), row_i, &mut buf);
-    });
-
-    let mut c_dft: VecZnxDft<FFT64> = module.new_vec_znx_dft(1, limbs_mat);
-    module.vmp_apply_dft(&mut c_dft, &a, &mat_znx_dft, &mut buf);
-
-    let mut c_big: VecZnxBig<FFT64> = c_dft.alias_as_vec_znx_big();
-    module.vec_znx_idft_tmp_a(&mut c_big, 0, &mut c_dft, 0);
-
-    let mut res: VecZnx = module.new_vec_znx(1, limbs_vec);
-    module.vec_znx_big_normalize(log_base2k, &mut res, 0, &c_big, 0, &mut buf);
-
-    let mut values_res: Vec<i64> = vec![i64::default(); n];
-    res.decode_vec_i64(0, log_base2k, log_k, &mut values_res);
-
-    res.print(n);
-
-    module.free();
-
-    println!("{:?}", values_res)
-}
diff --git a/base2k/examples/vmp.rs b/base2k/examples/vmp.rs
new file mode 100644
index 0000000..710744e
--- /dev/null
+++ b/base2k/examples/vmp.rs
@@ -0,0 +1,78 @@
+use base2k::{
+    Encoding, FFT64, MatZnxDft, MatZnxDftOps, Module, VecZnx, VecZnxBig, VecZnxBigOps, VecZnxDft, VecZnxDftOps, VecZnxOps,
+    ZnxInfos, ZnxLayout, alloc_aligned,
+};
+
+fn main() {
+    let log_n: i32 = 5;
+    let n: usize = 1 << log_n;
+
+    let module: Module<FFT64> = Module::<FFT64>::new(n);
+    let log_base2k: usize = 15;
+
+    let a_cols: usize = 2;
+    let a_size: usize = 5;
+
+    let log_k: usize = log_base2k * a_size - 5;
+
+    let mat_rows: usize = a_size;
+    let mat_cols_in: usize = a_cols;
+    let mat_cols_out: usize = 2;
+    let mat_size: usize = a_size + 1;
+
+    let mut tmp_bytes_vmp: Vec<u8> = alloc_aligned(
+        module.vmp_prepare_row_tmp_bytes(mat_cols_out, mat_size)
+            | module.vmp_apply_dft_tmp_bytes(
+                a_size,
+                a_size,
+                mat_rows,
+                mat_cols_in,
+                mat_cols_out,
+                mat_size,
+            ),
+    );
+
+    let mut tmp_bytes_dft: Vec<u8> = alloc_aligned(module.bytes_of_vec_znx_dft(mat_cols_out, mat_size));
+
+    let mut a: VecZnx = module.new_vec_znx(a_cols, a_size);
+
+    (0..a_cols).for_each(|i| {
+        let mut values: Vec<i64> = vec![i64::default(); n];
+        values[1 + i] = (1 << log_base2k) + 1;
+        a.encode_vec_i64(i, log_base2k, log_k, &values, 32);
+        a.normalize(log_base2k, i, &mut tmp_bytes_vmp);
+        a.print(n, i);
+        println!();
+    });
+
+    let mut mat_znx_dft: MatZnxDft<FFT64> = module.new_mat_znx_dft(mat_rows, mat_cols_in, mat_cols_out, mat_size);
+
+    (0..a.size()).for_each(|row_i| {
+        let mut tmp: VecZnx = module.new_vec_znx(mat_cols_out, mat_size);
+        (0..mat_cols_out).for_each(|j| {
+            tmp.at_mut(j, row_i)[1 + j] = 1 as i64;
+        });
+        (0..mat_cols_in).for_each(|j| {
+            module.vmp_prepare_row(&mut mat_znx_dft, row_i, j, &tmp, &mut tmp_bytes_vmp);
+        })
+    });
+
+    let mut c_dft: VecZnxDft<FFT64> = module.new_vec_znx_dft_from_bytes_borrow(mat_cols_out, mat_size, &mut tmp_bytes_dft);
+    module.vmp_apply_dft(&mut c_dft, &a, &mat_znx_dft, &mut tmp_bytes_vmp);
+
+    let mut res: VecZnx = module.new_vec_znx(mat_cols_out, a_size);
+    let mut c_big: VecZnxBig<FFT64> = c_dft.alias_as_vec_znx_big();
+    (0..mat_cols_out).for_each(|i| {
+        module.vec_znx_idft_tmp_a(&mut c_big, i, &mut c_dft, i);
+        module.vec_znx_big_normalize(log_base2k, &mut res, i, &c_big, i, &mut tmp_bytes_vmp);
+
+        let mut values_res: Vec<i64> = vec![i64::default(); n];
+        res.decode_vec_i64(i, log_base2k, log_k, &mut values_res);
+        res.print(n, i);
+        println!();
+        println!("{:?}", values_res);
+        println!();
+    });
+
+    module.free();
+}
diff --git a/base2k/spqlios-arithmetic b/base2k/spqlios-arithmetic
index e3d3247..8135d85 160000
--- a/base2k/spqlios-arithmetic
+++ b/base2k/spqlios-arithmetic
@@ -1 +1 @@
-Subproject commit e3d3247335faccf2b6361213c354cd61b958325e
+Subproject commit 8135d85e7ac14601568fdd228e7dedf88994f7cf
diff --git a/base2k/src/mat_znx_dft.rs b/base2k/src/mat_znx_dft.rs
index 104bd4b..470adcc 100644
--- a/base2k/src/mat_znx_dft.rs
+++ b/base2k/src/mat_znx_dft.rs
@@ -1,4 +1,4 @@
-use crate::znx_base::{GetZnxBase, ZnxAlloc, ZnxBase, ZnxInfos, ZnxLayout, ZnxSliceSize};
+use crate::znx_base::{GetZnxBase, ZnxBase, ZnxInfos, ZnxLayout, ZnxSliceSize};
 use crate::{Backend, FFT64, Module, alloc_aligned};
 use std::marker::PhantomData;
 
@@ -10,6 +10,8 @@
 /// See the trait [MatZnxDftOps] for additional information.
 pub struct MatZnxDft<B: Backend> {
     pub inner: ZnxBase,
+    pub cols_in: usize,
+    pub cols_out: usize,
     _marker: PhantomData<B>,
 }
 
@@ -35,18 +37,54 @@ impl ZnxLayout for MatZnxDft<FFT64> {
     type Scalar = f64;
 }
 
-impl ZnxAlloc<FFT64> for MatZnxDft<FFT64> {
-    type Scalar = u8;
+impl MatZnxDft<FFT64> {
+    pub fn new(module: &Module<FFT64>, rows: usize, cols_in: usize, cols_out: usize, size: usize) -> Self {
+        let bytes: Vec<u8> = alloc_aligned(Self::bytes_of(module, rows, cols_in, cols_out, size));
+        Self::from_bytes(module, rows, cols_in, cols_out, size, bytes)
+    }
 
-    fn from_bytes_borrow(module: &Module<FFT64>, rows: usize, cols: usize, size: usize, bytes: &mut [u8]) -> Self {
+    pub fn from_bytes(module: &Module<FFT64>, rows: usize, cols_in: usize, cols_out: usize, size: usize, mut bytes: Vec<u8>) -> Self {
+        let mut mat: MatZnxDft<FFT64> = Self::from_bytes_borrow(module, rows, cols_in, cols_out, size, &mut bytes);
+        mat.znx_mut().data = bytes;
+        mat
+    }
+
+    pub fn from_bytes_borrow(
+        module: &Module<FFT64>,
+        rows: usize,
+        cols_in: usize,
+        cols_out: usize,
+        size: usize,
+        bytes: &mut [u8],
+    ) -> Self {
+        debug_assert_eq!(
+            bytes.len(),
+            Self::bytes_of(module, rows, cols_in, cols_out, size)
+        );
         Self {
-            inner: ZnxBase::from_bytes_borrow(module.n(), rows, cols, size, bytes),
+            inner: ZnxBase::from_bytes_borrow(module.n(), rows, cols_out, size, bytes),
+            cols_in: cols_in,
+            cols_out: cols_out,
             _marker: PhantomData,
         }
     }
 
-    fn bytes_of(module: &Module<FFT64>, rows: usize, cols: usize, size: usize) -> usize {
-        unsafe { crate::ffi::vmp::bytes_of_vmp_pmat(module.ptr, rows as u64, size as u64) as usize * cols }
+    pub fn bytes_of(module: &Module<FFT64>, rows: usize, cols_in: usize, cols_out: usize, size: usize) -> usize {
+        unsafe {
+            crate::ffi::vmp::bytes_of_vmp_pmat(
+                module.ptr,
+                (rows * cols_in) as u64,
+                (size * cols_out) as u64,
+            ) as usize
+        }
+    }
+
+    pub fn cols_in(&self) -> usize {
+        self.cols_in
+    }
+
+    pub fn cols_out(&self) -> usize {
+        self.cols_out
     }
 }
diff --git a/base2k/src/mat_znx_dft_ops.rs b/base2k/src/mat_znx_dft_ops.rs
index 85177aa..48c3834 100644
--- a/base2k/src/mat_znx_dft_ops.rs
+++ b/base2k/src/mat_znx_dft_ops.rs
@@ -1,8 +1,9 @@
-use crate::ffi::vec_znx_big::vec_znx_big_t;
 use crate::ffi::vec_znx_dft::vec_znx_dft_t;
 use crate::ffi::vmp;
 use crate::znx_base::{ZnxInfos, ZnxLayout};
-use crate::{Backend, FFT64, MatZnxDft, Module, VecZnx, VecZnxBig, VecZnxDft, ZnxAlloc, assert_alignement};
+use crate::{
+    Backend, FFT64, MatZnxDft, Module, VecZnx, VecZnxBig, VecZnxBigOps, VecZnxDft, VecZnxDftOps, assert_alignement, is_aligned,
+};
 
 /// This trait implements methods for vector matrix product,
 /// that is, multiplying a [VecZnx] with a [MatZnxDft].
@@ -13,44 +14,45 @@ pub trait MatZnxDftOps<B: Backend> {
     ///
     /// * `rows`: number of rows (number of [VecZnxDft]).
     /// * `size`: size (number of limbs) of each [VecZnxDft].
-    fn new_mat_znx_dft(&self, rows: usize, cols: usize, size: usize) -> MatZnxDft<B>;
+    fn new_mat_znx_dft(&self, rows: usize, cols_in: usize, cols_out: usize, size: usize) -> MatZnxDft<B>;
 
-    fn bytes_of_mat_znx_dft(&self, rows: usize, cols: usize, size: usize) -> usize;
+    fn bytes_of_mat_znx_dft(&self, rows: usize, cols_in: usize, cols_out: usize, size: usize) -> usize;
 
-    fn new_mat_znx_dft_from_bytes(&self, rows: usize, cols: usize, size: usize, bytes: Vec<u8>) -> MatZnxDft<B>;
+    fn new_mat_znx_dft_from_bytes(
+        &self,
+        rows: usize,
+        cols_in: usize,
+        cols_out: usize,
+        size: usize,
+        bytes: Vec<u8>,
+    ) -> MatZnxDft<B>;
 
-    fn new_mat_znx_dft_from_bytes_borrow(&self, rows: usize, cols: usize, size: usize, bytes: &mut [u8]) -> MatZnxDft<B>;
+    fn new_mat_znx_dft_from_bytes_borrow(
+        &self,
+        rows: usize,
+        cols_in: usize,
+        cols_out: usize,
+        size: usize,
+        bytes: &mut [u8],
+    ) -> MatZnxDft<B>;
 
-    /// Returns the number of bytes needed as scratch space for [MatZnxDftOps::vmp_prepare_contiguous].
-    ///
-    /// # Arguments
-    ///
-    /// * `rows`: number of rows of the [MatZnxDft] used in [MatZnxDftOps::vmp_prepare_contiguous].
-    /// * `size`: number of size of the [MatZnxDft] used in [MatZnxDftOps::vmp_prepare_contiguous].
-    fn vmp_prepare_tmp_bytes(&self, rows: usize, cols: usize, size: usize) -> usize;
-
-    /// Prepares a [MatZnxDft] from a contiguous array of [i64].
-    /// The helper struct [Matrix3D] can be used to contruct and populate
-    /// the appropriate contiguous array.
-    ///
-    /// # Arguments
-    ///
-    /// * `b`: [MatZnxDft] on which the values are encoded.
-    /// * `a`: the contiguous array of [i64] of the 3D matrix to encode on the [MatZnxDft].
-    /// * `buf`: scratch space, the size of buf can be obtained with [MatZnxDftOps::vmp_prepare_tmp_bytes].
-    fn vmp_prepare_contiguous(&self, b: &mut MatZnxDft<B>, a: &[i64], buf: &mut [u8]);
+    /// Returns the number of bytes needed as scratch space for [MatZnxDftOps::vmp_prepare_row].
+    fn vmp_prepare_row_tmp_bytes(&self, cols_out: usize, size: usize) -> usize;
 
     /// Prepares the ith-row of [MatZnxDft] from a [VecZnx].
     ///
     /// # Arguments
     ///
     /// * `b`: [MatZnxDft] on which the values are encoded.
-    /// * `a`: the vector of [VecZnx] to encode on the [MatZnxDft].
-    /// * `row_i`: the index of the row to prepare.
+    /// * `row_i`: the row of the [MatZnxDft] to prepare.
+    /// * `a`: the [VecZnx] to encode on the i-th row of the [MatZnxDft].
     /// * `buf`: scratch space, the size of buf can be obtained with [MatZnxDftOps::vmp_prepare_row_tmp_bytes].
-    fn vmp_prepare_row(&self, b: &mut MatZnxDft<B>, a: &[i64], row_i: usize, tmp_bytes: &mut [u8]);
+    fn vmp_prepare_row(&self, b: &mut MatZnxDft<B>, b_row: usize, b_col_in: usize, a: &VecZnx, tmp_bytes: &mut [u8]);
+
+    /// Returns the number of bytes needed as scratch space for [MatZnxDftOps::vmp_extract_row].
+    fn vmp_extract_row_tmp_bytes(&self, cols_out: usize, size: usize) -> usize;
 
     /// Extracts the ith-row of [MatZnxDft] into a [VecZnxBig].
     ///
     /// # Arguments
     ///
@@ -59,7 +61,15 @@
     /// * `b`: the [VecZnxBig] into which the row of the [MatZnxDft] is extracted.
     /// * `a`: [MatZnxDft] on which the values are encoded.
     /// * `row_i`: the index of the row to extract.
-    fn vmp_extract_row(&self, b: &mut VecZnxBig<B>, a: &MatZnxDft<B>, row_i: usize);
+    fn vmp_extract_row(
+        &self,
+        log_base2k: usize,
+        b: &mut VecZnx,
+        a: &MatZnxDft<B>,
+        b_row: usize,
+        b_col_in: usize,
+        tmp_bytes: &mut [u8],
+    );
 
     /// Prepares the ith-row of [MatZnxDft] from a [VecZnxDft].
     ///
     /// # Arguments
     ///
     /// * `b`: [MatZnxDft] on which the values are encoded.
     /// * `a`: the [VecZnxDft] to encode on the [MatZnxDft].
@@ -70,7 +80,7 @@
     /// * `row_i`: the index of the row to prepare.
     ///
     /// The size of buf can be obtained with [MatZnxDftOps::vmp_prepare_row_tmp_bytes].
-    fn vmp_prepare_row_dft(&self, b: &mut MatZnxDft<B>, a: &VecZnxDft<B>, row_i: usize);
+    fn vmp_prepare_row_dft(&self, b: &mut MatZnxDft<B>, b_row: usize, b_col_in: usize, a: &VecZnxDft<B>);
 
     /// Extracts the ith-row of [MatZnxDft] into a [VecZnxDft].
     ///
     /// # Arguments
     ///
     /// * `b`: the [VecZnxDft] into which the row of the [MatZnxDft] is extracted.
     /// * `a`: [MatZnxDft] on which the values are encoded.
@@ -79,7 +89,7 @@
     /// * `row_i`: the index of the row to extract.
-    fn vmp_extract_row_dft(&self, b: &mut VecZnxDft<B>, row_i: usize, a: &MatZnxDft<B>);
+    fn vmp_extract_row_dft(&self, b: &mut VecZnxDft<B>, a: &MatZnxDft<B>, a_row: usize, a_col_in: usize);
 
     /// Returns the size of the scratch space necessary for [MatZnxDftOps::vmp_apply_dft].
     ///
     /// # Arguments
     ///
     /// * `c_size`: size (number of limbs) of the output [VecZnxDft].
     /// * `a_size`: size (number of limbs) of the input [VecZnx].
     /// * `rows`: number of rows of the input [MatZnxDft].
@@ -89,7 +99,15 @@
     /// * `size`: size (number of limbs) of the input [MatZnxDft].
-    fn vmp_apply_dft_tmp_bytes(&self, c_size: usize, a_size: usize, b_rows: usize, b_size: usize) -> usize;
+    fn vmp_apply_dft_tmp_bytes(
+        &self,
+        c_size: usize,
+        a_size: usize,
+        b_rows: usize,
+        b_cols_in: usize,
+        b_cols_out: usize,
+        b_size: usize,
+    ) -> usize;
 
     /// Applies the vector matrix product [VecZnxDft] x [MatZnxDft].
     ///
@@ -117,32 +135,6 @@
     /// * `buf`: scratch space, the size can be obtained with [MatZnxDftOps::vmp_apply_dft_tmp_bytes].
     fn vmp_apply_dft(&self, c: &mut VecZnxDft<B>, a: &VecZnx, b: &MatZnxDft<B>, buf: &mut [u8]);
-
-    /// Applies the vector matrix product [VecZnxDft] x [MatZnxDft] and adds on the receiver.
-    ///
-    /// A vector matrix product is equivalent to a sum of [crate::SvpPPolOps::svp_apply_dft]
-    /// where each [crate::Scalar] is a limb of the input [VecZnxDft] (equivalent to an [crate::SvpPPol])
-    /// and each vector a [VecZnxDft] (row) of the [MatZnxDft].
-    ///
-    /// As such, given an input [VecZnx] of `i` size and a [MatZnxDft] of `i` rows and
-    /// `j` size, the output is a [VecZnx] of `j` size.
-    ///
-    /// If there is a mismatch between the dimensions the largest valid ones are used.
-    ///
-    /// ```text
-    /// |a b c d| x |e f g| = (a * |e f g| + b * |h i j| + c * |k l m|) = |n o p|
-    ///             |h i j|
-    ///             |k l m|
-    /// ```
-    /// where each element is a [VecZnxDft].
-    ///
-    /// # Arguments
-    ///
-    /// * `c`: the operand on which the output of the vector matrix product is added, as a [VecZnxDft].
-    /// * `a`: the left operand [VecZnx] of the vector matrix product.
-    /// * `b`: the right operand [MatZnxDft] of the vector matrix product.
-    /// * `buf`: scratch space, the size can be obtained with [MatZnxDftOps::vmp_apply_dft_tmp_bytes].
-    fn vmp_apply_dft_add(&self, c: &mut VecZnxDft<B>, a: &VecZnx, b: &MatZnxDft<B>, buf: &mut [u8]);
-
     /// Returns the size of the scratch space necessary for [MatZnxDftOps::vmp_apply_dft_to_dft].
     ///
     /// # Arguments
     ///
@@ -151,7 +143,17 @@
     /// * `c_size`: size (number of limbs) of the output [VecZnxDft].
     /// * `a_size`: size (number of limbs) of the input [VecZnxDft].
     /// * `rows`: number of rows of the input [MatZnxDft].
     /// * `size`: size (number of limbs) of the input [MatZnxDft].
-    fn vmp_apply_dft_to_dft_tmp_bytes(&self, c_size: usize, a_size: usize, rows: usize, size: usize) -> usize;
+    fn vmp_apply_dft_to_dft_tmp_bytes(
+        &self,
+        c_cols: usize,
+        c_size: usize,
+        a_cols: usize,
+        a_size: usize,
+        b_rows: usize,
+        b_cols_in: usize,
+        b_cols_out: usize,
+        b_size: usize,
+    ) -> usize;
 
     /// Applies the vector matrix product [VecZnxDft] x [MatZnxDft].
     /// The size of `buf` is given by [MatZnxDftOps::vmp_apply_dft_to_dft_tmp_bytes].
     ///
@@ -179,308 +181,385 @@
     /// * `b`: the right operand [MatZnxDft] of the vector matrix product.
     /// * `buf`: scratch space, the size can be obtained with [MatZnxDftOps::vmp_apply_dft_to_dft_tmp_bytes].
     fn vmp_apply_dft_to_dft(&self, c: &mut VecZnxDft<B>, a: &VecZnxDft<B>, b: &MatZnxDft<B>, buf: &mut [u8]);
-
-    /// Applies the vector matrix product [VecZnxDft] x [MatZnxDft] and adds on top of the receiver instead of overwritting it.
-    /// The size of `buf` is given by [MatZnxDftOps::vmp_apply_dft_to_dft_tmp_bytes].
-    ///
-    /// A vector matrix product is equivalent to a sum of [crate::SvpPPolOps::svp_apply_dft]
-    /// where each [crate::Scalar] is a limb of the input [VecZnxDft] (equivalent to an [crate::SvpPPol])
-    /// and each vector a [VecZnxDft] (row) of the [MatZnxDft].
-    ///
-    /// As such, given an input [VecZnx] of `i` size and a [MatZnxDft] of `i` rows and
-    /// `j` size, the output is a [VecZnx] of `j` size.
-    ///
-    /// If there is a mismatch between the dimensions the largest valid ones are used.
-    ///
-    /// ```text
-    /// |a b c d| x |e f g| = (a * |e f g| + b * |h i j| + c * |k l m|) = |n o p|
-    ///             |h i j|
-    ///             |k l m|
-    /// ```
-    /// where each element is a [VecZnxDft].
-    ///
-    /// # Arguments
-    ///
-    /// * `c`: the operand on which the output of the vector matrix product is added, as a [VecZnxDft].
-    /// * `a`: the left operand [VecZnxDft] of the vector matrix product.
-    /// * `b`: the right operand [MatZnxDft] of the vector matrix product.
-    /// * `buf`: scratch space, the size can be obtained with [MatZnxDftOps::vmp_apply_dft_to_dft_tmp_bytes].
-    fn vmp_apply_dft_to_dft_add(&self, c: &mut VecZnxDft<B>, a: &VecZnxDft<B>, b: &MatZnxDft<B>, buf: &mut [u8]);
-
-    /// Applies the vector matrix product [VecZnxDft] x [MatZnxDft] in place.
-    /// The size of `buf` is given by [MatZnxDftOps::vmp_apply_dft_to_dft_tmp_bytes].
-    ///
-    /// A vector matrix product is equivalent to a sum of [crate::SvpPPolOps::svp_apply_dft]
-    /// where each [crate::Scalar] is a limb of the input [VecZnxDft] (equivalent to an [crate::SvpPPol])
-    /// and each vector a [VecZnxDft] (row) of the [MatZnxDft].
-    ///
-    /// As such, given an input [VecZnx] of `i` size and a [MatZnxDft] of `i` rows and
-    /// `j` size, the output is a [VecZnx] of `j` size.
-    ///
-    /// If there is a mismatch between the dimensions the largest valid ones are used.
-    ///
-    /// ```text
-    /// |a b c d| x |e f g| = (a * |e f g| + b * |h i j| + c * |k l m|) = |n o p|
-    ///             |h i j|
-    ///             |k l m|
-    /// ```
-    /// where each element is a [VecZnxDft].
-    ///
-    /// # Arguments
-    ///
-    /// * `b`: the input and output of the vector matrix product, as a [VecZnxDft].
-    /// * `a`: the right operand [MatZnxDft] of the vector matrix product.
-    /// * `buf`: scratch space, the size can be obtained with [MatZnxDftOps::vmp_apply_dft_to_dft_tmp_bytes].
-    fn vmp_apply_dft_to_dft_inplace(&self, b: &mut VecZnxDft<B>, a: &MatZnxDft<B>, buf: &mut [u8]);
 }
 
 impl MatZnxDftOps<FFT64> for Module<FFT64> {
-    fn new_mat_znx_dft(&self, rows: usize, cols: usize, size: usize) -> MatZnxDft<FFT64> {
-        MatZnxDft::<FFT64>::new(self, rows, cols, size)
+    fn new_mat_znx_dft(&self, rows: usize, cols_in: usize, cols_out: usize, size: usize) -> MatZnxDft<FFT64> {
+        MatZnxDft::<FFT64>::new(self, rows, cols_in, cols_out, size)
     }
 
-    fn bytes_of_mat_znx_dft(&self, rows: usize, cols: usize, size: usize) -> usize {
-        MatZnxDft::<FFT64>::bytes_of(self, rows, cols, size)
+    fn bytes_of_mat_znx_dft(&self, rows: usize, cols_in: usize, cols_out: usize, size: usize) -> usize {
+        MatZnxDft::<FFT64>::bytes_of(self, rows, cols_in, cols_out, size)
     }
 
-    fn new_mat_znx_dft_from_bytes(&self, rows: usize, cols: usize, size: usize, bytes: Vec<u8>) -> MatZnxDft<FFT64> {
-        MatZnxDft::<FFT64>::from_bytes(self, rows, cols, size, bytes)
+    fn new_mat_znx_dft_from_bytes(
+        &self,
+        rows: usize,
+        cols_in: usize,
+        cols_out: usize,
+        size: usize,
+        bytes: Vec<u8>,
+    ) -> MatZnxDft<FFT64> {
+        MatZnxDft::<FFT64>::from_bytes(self, rows, cols_in, cols_out, size, bytes)
     }
 
-    fn new_mat_znx_dft_from_bytes_borrow(&self, rows: usize, cols: usize, size: usize, bytes: &mut [u8]) -> MatZnxDft<FFT64> {
-        MatZnxDft::<FFT64>::from_bytes_borrow(self, rows, cols, size, bytes)
+    fn new_mat_znx_dft_from_bytes_borrow(
+        &self,
+        rows: usize,
+        cols_in: usize,
+        cols_out: usize,
+        size: usize,
+        bytes: &mut [u8],
+    ) -> MatZnxDft<FFT64> {
+        MatZnxDft::<FFT64>::from_bytes_borrow(self, rows, cols_in, cols_out, size, bytes)
     }
 
-    fn vmp_prepare_tmp_bytes(&self, rows: usize, cols: usize, size: usize) -> usize {
-        unsafe { vmp::vmp_prepare_tmp_bytes(self.ptr, rows as u64, (size * cols) as u64) as usize }
+    fn vmp_prepare_row_tmp_bytes(&self, cols_out: usize, size: usize) -> usize {
+        self.bytes_of_vec_znx_dft(cols_out, size)
     }
 
-    fn vmp_prepare_contiguous(&self, b: &mut MatZnxDft<FFT64>, a: &[i64], tmp_bytes: &mut [u8]) {
+    fn vmp_prepare_row(&self, b: &mut MatZnxDft<FFT64>, b_row: usize, b_col_in: usize, a: &VecZnx, tmp_bytes: &mut [u8]) {
         #[cfg(debug_assertions)]
         {
-            assert_eq!(a.len(), b.n() * b.poly_count());
-            assert!(tmp_bytes.len() >= self.vmp_prepare_tmp_bytes(b.rows(), b.cols(), b.size()));
-            assert_alignement(tmp_bytes.as_ptr());
+            assert_eq!(b.n(), self.n());
+            assert_eq!(a.n(), self.n());
+            assert_eq!(
+                a.cols(),
+                b.cols_out(),
+                "a.cols(): {} != b.cols_out(): {}",
+                a.cols(),
+                b.cols_out()
+            );
+            assert!(
+                b_row < b.rows(),
+                "b_row: {} >= b.rows(): {}",
+                b_row,
+                b.rows()
+            );
+            assert!(
+                b_col_in < b.cols_in(),
+                "b_col_in: {} >= b.cols_in(): {}",
+                b_col_in,
+                b.cols_in()
+            );
+            assert_eq!(
+                b.size(),
+                a.size(),
+                "b.size(): {} != a.size(): {}",
+                b.size(),
+                a.size()
+            );
+            assert!(tmp_bytes.len() >= self.vmp_prepare_row_tmp_bytes(a.cols(), a.size()));
+            assert!(is_aligned(tmp_bytes.as_ptr()))
         }
-        unsafe {
-            vmp::vmp_prepare_contiguous(
-                self.ptr,
-                b.as_mut_ptr() as *mut vmp::vmp_pmat_t,
-                a.as_ptr(),
-                b.rows() as u64,
-                (b.size() * b.cols()) as u64,
-                tmp_bytes.as_mut_ptr(),
-            );
-        }
-    }
+
+        let cols_out: usize = a.cols();
+        let a_size: usize = a.size();
+
+        let (tmp_bytes_a_dft, _) = tmp_bytes.split_at_mut(self.bytes_of_vec_znx_dft(cols_out, a_size));
+
+        let mut a_dft: VecZnxDft<FFT64> = self.new_vec_znx_dft_from_bytes_borrow(cols_out, a_size, tmp_bytes_a_dft);
+        (0..cols_out).for_each(|i| self.vec_znx_dft(&mut a_dft, i, &a, i));
+
+        Self::vmp_prepare_row_dft(&self, b, b_row, b_col_in, &a_dft);
+    }
+
+    fn vmp_extract_row_tmp_bytes(&self, cols_out: usize, size: usize) -> usize {
+        self.bytes_of_vec_znx_dft(cols_out, size) + self.vec_znx_big_normalize_tmp_bytes()
+    }
+
+    fn vmp_extract_row(
+        &self,
+        log_base2k: usize,
+        b: &mut VecZnx,
+        a: &MatZnxDft<FFT64>,
+        a_row: usize,
+        a_col_in: usize,
+        tmp_bytes: &mut [u8],
+    ) {
+        #[cfg(debug_assertions)]
+        {
+            assert_eq!(b.n(), self.n());
+            assert_eq!(a.n(), self.n());
+            assert_eq!(
+                b.cols(),
+                a.cols_out(),
+                "b.cols(): {} != a.cols_out(): {}",
+                b.cols(),
+                a.cols_out()
+            );
+            assert!(
+                a_row < a.rows(),
+                "a_row: {} >= a.rows(): {}",
+                a_row,
+                a.rows()
+            );
+            assert!(
+                a_col_in < a.cols_in(),
+                "a_col_in: {} >= a.cols_in(): {}",
+                a_col_in,
+                a.cols_in()
+            );
+            assert_eq!(
+                b.size(),
+                a.size(),
+                "b.size(): {} != a.size(): {}",
+                b.size(),
+                a.size()
+            );
+            assert!(tmp_bytes.len() >= self.vmp_extract_row_tmp_bytes(a.cols(), a.size()));
+            assert!(is_aligned(tmp_bytes.as_ptr()))
+        }
+
+        let cols_out: usize = b.cols();
+        let size: usize = b.size();
+
+        let (bytes_a_dft, tmp_bytes) = tmp_bytes.split_at_mut(self.bytes_of_vec_znx_dft(cols_out, size));
+        let mut b_dft: VecZnxDft<FFT64> = self.new_vec_znx_dft_from_bytes_borrow(cols_out, size, bytes_a_dft);
+        Self::vmp_extract_row_dft(&self, &mut b_dft, a, a_row, a_col_in);
+        let mut b_big: VecZnxBig<FFT64> = b_dft.alias_as_vec_znx_big();
+        (0..cols_out).for_each(|i| {
+            self.vec_znx_idft_tmp_a(&mut b_big, i, &mut b_dft, i);
+            self.vec_znx_big_normalize(log_base2k, b, i, &b_big, i, tmp_bytes);
+        });
+    }
+
+    fn vmp_prepare_row_dft(&self, b: &mut MatZnxDft<FFT64>, b_row: usize, b_col_in: usize, a: &VecZnxDft<FFT64>) {
+        #[cfg(debug_assertions)]
+        {
+            assert_eq!(b.n(), self.n());
+            assert_eq!(a.n(), self.n());
+            assert_eq!(
+                a.cols(),
+                b.cols_out(),
+                "a.cols(): {} != b.cols_out(): {}",
+                a.cols(),
+                b.cols_out()
+            );
+            assert!(
+                b_row < b.rows(),
+                "b_row: {} >= b.rows(): {}",
+                b_row,
+                b.rows()
+            );
+            assert!(
+                b_col_in < b.cols_in(),
+                "b_col_in: {} >= b.cols_in(): {}",
+                b_col_in,
+                b.cols_in()
+            );
+            assert_eq!(
+                b.size(),
+                a.size(),
+                "b.size(): {} != a.size(): {}",
+                b.size(),
+                a.size()
+            );
         }
-    }
 
-    fn vmp_prepare_row(&self, b: &mut MatZnxDft<FFT64>, a: &[i64], row_i: usize, tmp_bytes: &mut [u8]) {
-        #[cfg(debug_assertions)]
-        {
-            assert_eq!(a.len(), b.size() * self.n() * b.cols());
-            assert!(tmp_bytes.len() >= self.vmp_prepare_tmp_bytes(b.rows(), b.cols(), b.size()));
-            assert_alignement(tmp_bytes.as_ptr());
-        }
-        unsafe {
-            vmp::vmp_prepare_row(
-                self.ptr,
-                b.as_mut_ptr() as *mut vmp::vmp_pmat_t,
-                a.as_ptr(),
-                row_i as u64,
-                b.rows() as u64,
-                (b.size() * b.cols()) as u64,
-                tmp_bytes.as_mut_ptr(),
-            );
-        }
-    }
-
-    fn vmp_extract_row(&self, b: &mut VecZnxBig<FFT64>, a: &MatZnxDft<FFT64>, row_i: usize) {
-        #[cfg(debug_assertions)]
-        {
-            assert_eq!(a.n(), b.n());
-            assert_eq!(a.size(), b.size());
-            assert_eq!(a.cols(), b.cols());
-        }
-        unsafe {
-            vmp::vmp_extract_row(
-                self.ptr,
-                b.as_mut_ptr() as *mut vec_znx_big_t,
-                a.as_ptr() as *const vmp::vmp_pmat_t,
-                row_i as u64,
-                a.rows() as u64,
-                (a.size() * a.cols()) as u64,
-            );
-        }
-    }
-
-    fn vmp_prepare_row_dft(&self, b: &mut MatZnxDft<FFT64>, a: &VecZnxDft<FFT64>, row_i: usize) {
-        #[cfg(debug_assertions)]
-        {
-            assert_eq!(a.n(), b.n());
-            assert_eq!(a.size(), b.size());
-        }
         unsafe {
             vmp::vmp_prepare_row_dft(
                 self.ptr,
                 b.as_mut_ptr() as *mut vmp::vmp_pmat_t,
                 a.as_ptr() as *const vec_znx_dft_t,
-                row_i as u64,
-                b.rows() as u64,
-                b.size() as u64,
+                (b_row * b.cols_in() + b_col_in) as u64,
+                (b.rows() * b.cols_in()) as u64,
+                (b.size() * b.cols_out()) as u64,
             );
         }
     }
 
-    fn vmp_extract_row_dft(&self, b: &mut VecZnxDft<FFT64>, row_i: usize, a: &MatZnxDft<FFT64>) {
+    fn vmp_extract_row_dft(&self, b: &mut VecZnxDft<FFT64>, a: &MatZnxDft<FFT64>, a_row: usize, a_col_in: usize) {
         #[cfg(debug_assertions)]
         {
-            assert_eq!(a.n(), b.n());
-            assert_eq!(a.size(), b.size());
+            assert_eq!(b.n(), self.n());
+            assert_eq!(a.n(), self.n());
+            assert_eq!(
+                b.cols(),
+                a.cols_out(),
+                "b.cols(): {} != a.cols_out(): {}",
+                b.cols(),
+                a.cols_out()
+            );
+            assert!(
+                a_row < a.rows(),
+                "a_row: {} >= a.rows(): {}",
+                a_row,
+                a.rows()
+            );
+            assert!(
+                a_col_in < a.cols_in(),
+                "a_col_in: {} >= a.cols_in(): {}",
+                a_col_in,
+                a.cols_in()
+            );
+            assert_eq!(
+                b.size(),
+                a.size(),
+                "b.size(): {} != a.size(): {}",
+                b.size(),
+                a.size()
+            );
         }
         unsafe {
             vmp::vmp_extract_row_dft(
                 self.ptr,
                 b.as_mut_ptr() as *mut vec_znx_dft_t,
                 a.as_ptr() as *const vmp::vmp_pmat_t,
-                row_i as u64,
-                a.rows() as u64,
-                a.size() as u64,
+                (a_row * a.cols_in() + a_col_in) as u64,
+                (a.rows() * a.cols_in()) as u64,
+                (a.size() * a.cols_out()) as u64,
             );
         }
     }
 
-    fn vmp_apply_dft_tmp_bytes(&self, res_size: usize, a_size: usize, b_rows: usize, b_size: usize) -> usize {
+    fn vmp_apply_dft_tmp_bytes(
+        &self,
+        res_size: usize,
+        a_size: usize,
+        b_rows: usize,
+        b_cols_in: usize,
+        b_cols_out: usize,
+        b_size: usize,
+    ) -> usize {
         unsafe {
             vmp::vmp_apply_dft_tmp_bytes(
                 self.ptr,
                 res_size as u64,
                 a_size as u64,
-                b_rows as u64,
-                b_size as u64,
+                (b_rows * b_cols_in) as u64,
+                (b_size * b_cols_out) as u64,
             ) as usize
         }
     }
 
     fn vmp_apply_dft(&self, c: &mut VecZnxDft<FFT64>, a: &VecZnx, b: &MatZnxDft<FFT64>, tmp_bytes: &mut [u8]) {
-        debug_assert!(tmp_bytes.len() >= self.vmp_apply_dft_tmp_bytes(c.size(), a.size(), b.rows(), b.size()));
+        debug_assert!(
+            tmp_bytes.len()
+                >= self.vmp_apply_dft_tmp_bytes(
+                    c.size(),
+                    a.size(),
+                    b.rows(),
+                    b.cols_in(),
+                    b.cols_out(),
+                    b.size()
+                )
+        );
         #[cfg(debug_assertions)]
         {
+            assert_eq!(c.n(), self.n());
+            assert_eq!(b.n(), self.n());
+            assert_eq!(a.n(), self.n());
+            assert_eq!(
+                c.cols(),
+                b.cols_out(),
+                "c.cols(): {} != b.cols_out: {}",
+                c.cols(),
+                b.cols_out()
+            );
+            assert_eq!(
+                a.cols(),
+                b.cols_in(),
+                "a.cols(): {} != b.cols_in: {}",
+                a.cols(),
+                b.cols_in()
+            );
+            assert!(
+                tmp_bytes.len()
+                    >= self.vmp_apply_dft_tmp_bytes(
+                        c.size(),
+                        a.size(),
+                        b.rows(),
+                        b.cols_in(),
+                        b.cols_out(),
+                        b.size()
+                    )
+            );
             assert_alignement(tmp_bytes.as_ptr());
         }
         unsafe {
             vmp::vmp_apply_dft(
                 self.ptr,
                 c.as_mut_ptr() as *mut vec_znx_dft_t,
-                c.size() as u64,
+                (c.size() * c.cols()) as u64,
                 a.as_ptr(),
-                a.size() as u64,
-                (a.n() * a.cols()) as u64,
+                (a.size() * a.cols()) as u64,
+                a.n() as u64,
                 b.as_ptr() as *const vmp::vmp_pmat_t,
-                b.rows() as u64,
-                b.size() as u64,
+                (b.rows() * b.cols_in()) as u64,
+                (b.size() * b.cols_out()) as u64,
                 tmp_bytes.as_mut_ptr(),
             )
         }
     }
 
-    fn vmp_apply_dft_add(&self, c: &mut VecZnxDft<FFT64>, a: &VecZnx, b: &MatZnxDft<FFT64>, tmp_bytes: &mut [u8]) {
-        debug_assert!(tmp_bytes.len() >= self.vmp_apply_dft_tmp_bytes(c.size(), a.size(), b.rows(), b.size()));
-        #[cfg(debug_assertions)]
-        {
-            assert_alignement(tmp_bytes.as_ptr());
-        }
-        unsafe {
-            vmp::vmp_apply_dft_add(
-                self.ptr,
-                c.as_mut_ptr() as *mut vec_znx_dft_t,
-                c.size() as u64,
-                a.as_ptr(),
-                a.size() as u64,
-                (a.n() * a.size()) as u64,
-                b.as_ptr() as *const vmp::vmp_pmat_t,
-                b.rows() as u64,
-                b.size() as u64,
-                tmp_bytes.as_mut_ptr(),
-            )
-        }
-    }
-
-    fn vmp_apply_dft_to_dft_tmp_bytes(&self, res_size: usize, a_size: usize, gct_rows: usize, gct_size: usize) -> usize {
+    fn vmp_apply_dft_to_dft_tmp_bytes(
+        &self,
+        res_cols: usize,
+        res_size: usize,
+        a_cols: usize,
+        a_size: usize,
+        b_rows: usize,
+        b_cols_in: usize,
+        b_cols_out: usize,
+        b_size: usize,
+    ) -> usize {
         unsafe {
             vmp::vmp_apply_dft_to_dft_tmp_bytes(
                 self.ptr,
-                res_size as u64,
-                a_size as u64,
-                gct_rows as u64,
-                gct_size as u64,
+                (res_size * res_cols) as u64,
+                (a_size * a_cols) as u64,
+                (b_rows * b_cols_in) as u64,
+                (b_size * b_cols_out) as u64,
             ) as usize
         }
     }
 
     fn vmp_apply_dft_to_dft(&self, c: &mut VecZnxDft<FFT64>, a: &VecZnxDft<FFT64>, b: &MatZnxDft<FFT64>, tmp_bytes: &mut [u8]) {
-        debug_assert!(tmp_bytes.len() >= self.vmp_apply_dft_to_dft_tmp_bytes(c.size(), a.size(), b.rows(), b.size()));
         #[cfg(debug_assertions)]
         {
+            assert_eq!(c.n(), self.n());
+            assert_eq!(b.n(), self.n());
+            assert_eq!(a.n(), self.n());
+            assert_eq!(
+                c.cols(),
+                b.cols_out(),
+                "c.cols(): {} != b.cols_out: {}",
+                c.cols(),
+                b.cols_out()
+            );
+            assert_eq!(
+                a.cols(),
+                b.cols_in(),
+                "a.cols(): {} != b.cols_in: {}",
+                a.cols(),
+                b.cols_in()
+            );
+            assert!(
+                tmp_bytes.len()
+                    >= self.vmp_apply_dft_to_dft_tmp_bytes(
+                        c.cols(),
+                        c.size(),
+                        a.cols(),
+                        a.size(),
+                        b.rows(),
+                        b.cols_in(),
+                        b.cols_out(),
+                        b.size()
+                    )
+            );
             assert_alignement(tmp_bytes.as_ptr());
         }
         unsafe {
             vmp::vmp_apply_dft_to_dft(
                 self.ptr,
                 c.as_mut_ptr() as *mut vec_znx_dft_t,
-                c.size() as u64,
+                c.poly_count() as u64,
                 a.as_ptr() as *const vec_znx_dft_t,
-                a.size() as u64,
+                a.poly_count() as u64,
                 b.as_ptr() as *const vmp::vmp_pmat_t,
                 b.rows() as u64,
-                b.size() as u64,
-                tmp_bytes.as_mut_ptr(),
-            )
-        }
-    }
-
-    fn vmp_apply_dft_to_dft_add(
-        &self,
-        c: &mut VecZnxDft<FFT64>,
-        a: &VecZnxDft<FFT64>,
-        b: &MatZnxDft<FFT64>,
-        tmp_bytes: &mut [u8],
-    ) {
-        debug_assert!(tmp_bytes.len() >= self.vmp_apply_dft_to_dft_tmp_bytes(c.size(), a.size(), b.rows(), b.size()));
-        #[cfg(debug_assertions)]
-        {
-            assert_alignement(tmp_bytes.as_ptr());
-        }
-        unsafe {
-            vmp::vmp_apply_dft_to_dft_add(
-                self.ptr,
-                c.as_mut_ptr() as *mut vec_znx_dft_t,
-                c.size() as u64,
-                a.as_ptr() as *const vec_znx_dft_t,
-                a.size() as u64,
-                b.as_ptr() as *const vmp::vmp_pmat_t,
-                b.rows() as u64,
-                b.size() as u64,
-                tmp_bytes.as_mut_ptr(),
-            )
-        }
-    }
-
-    fn vmp_apply_dft_to_dft_inplace(&self, b: &mut VecZnxDft<FFT64>, a: &MatZnxDft<FFT64>, tmp_bytes: &mut [u8]) {
-        debug_assert!(tmp_bytes.len() >= self.vmp_apply_dft_to_dft_tmp_bytes(b.size(), b.size(), a.rows(), a.size()));
-        #[cfg(debug_assertions)]
-        {
-            assert_alignement(tmp_bytes.as_ptr());
-        }
-        unsafe {
-            vmp::vmp_apply_dft_to_dft(
-                self.ptr,
-                b.as_mut_ptr() as *mut vec_znx_dft_t,
-                b.size() as u64,
-                b.as_ptr() as *mut vec_znx_dft_t,
-                b.size() as u64,
-                a.as_ptr() as *const vmp::vmp_pmat_t,
-                a.rows() as u64,
-                a.size() as u64,
+                (b.size() * b.cols()) as u64,
                 tmp_bytes.as_mut_ptr(),
             )
         }
     }
@@ -497,38 +576,52 @@ mod tests {
 
     #[test]
     fn vmp_prepare_row_dft() {
-        let module: Module<FFT64> = Module::<FFT64>::new(32);
-        let vpmat_rows: usize = 4;
-        let vpmat_size: usize = 5;
+        let module: Module<FFT64> = Module::<FFT64>::new(16);
         let log_base2k: usize = 8;
-        let mut a: VecZnx = module.new_vec_znx(1, vpmat_size);
-        let mut a_dft: VecZnxDft<FFT64> = module.new_vec_znx_dft(1, vpmat_size);
-        let mut a_big: VecZnxBig<FFT64> = module.new_vec_znx_big(1, vpmat_size);
-        let mut b_big: VecZnxBig<FFT64> = module.new_vec_znx_big(1, vpmat_size);
-        let mut b_dft: VecZnxDft<FFT64> = module.new_vec_znx_dft(1, vpmat_size);
-        let mut vmpmat_0: MatZnxDft<FFT64> = module.new_mat_znx_dft(vpmat_rows, 1, vpmat_size);
-        let mut vmpmat_1: MatZnxDft<FFT64> = module.new_mat_znx_dft(vpmat_rows, 1, vpmat_size);
+        let mat_rows: usize = 4;
+        let mat_cols_in: usize = 2;
+        let mat_cols_out: usize = 2;
+        let mat_size: usize = 5;
+        let mut a: VecZnx = module.new_vec_znx(mat_cols_out, mat_size);
+        let mut b: VecZnx = module.new_vec_znx(mat_cols_out, mat_size);
+        let mut a_dft: VecZnxDft<FFT64> = module.new_vec_znx_dft(mat_cols_out, mat_size);
+        let mut a_big: VecZnxBig<FFT64> = module.new_vec_znx_big(mat_cols_out, mat_size);
+        let mut b_dft: VecZnxDft<FFT64> = module.new_vec_znx_dft(mat_cols_out, mat_size);
+        let mut vmpmat_0: MatZnxDft<FFT64> = module.new_mat_znx_dft(mat_rows, mat_cols_in, mat_cols_out, mat_size);
+        let mut vmpmat_1: MatZnxDft<FFT64> = module.new_mat_znx_dft(mat_rows, mat_cols_in, mat_cols_out, mat_size);
 
-        let mut tmp_bytes: Vec<u8> = alloc_aligned(module.vmp_prepare_tmp_bytes(vpmat_rows, 1, vpmat_size));
+        let mut tmp_bytes: Vec<u8> =
+            alloc_aligned(module.vmp_prepare_row_tmp_bytes(mat_cols_out, mat_size) | module.vec_znx_big_normalize_tmp_bytes());
 
-        for row_i in 0..vpmat_rows {
-            let mut source: Source = Source::new([0u8; 32]);
-            module.fill_uniform(log_base2k, &mut a, 0, vpmat_size, &mut source);
-            module.vec_znx_dft(&mut a_dft, 0, &a, 0);
-            module.vmp_prepare_row(&mut vmpmat_0, &a.raw(), row_i, &mut tmp_bytes);
+        for col_in in 0..mat_cols_in {
+            for row_i in 0..mat_rows {
+                let mut source: Source = Source::new([0u8; 32]);
 
-            // Checks that prepare(mat_znx_dft, a) = prepare_dft(mat_znx_dft, a_dft)
-            module.vmp_prepare_row_dft(&mut vmpmat_1, &a_dft, row_i);
-            assert_eq!(vmpmat_0.raw(), vmpmat_1.raw());
+                (0..mat_cols_out).for_each(|col_out| {
+                    module.fill_uniform(log_base2k, &mut a, col_out, mat_size, &mut source);
+                    module.vec_znx_dft(&mut a_dft, col_out, &a, col_out);
+                });
 
-            // Checks that a_dft = extract_dft(prepare(mat_znx_dft, a), b_dft)
-            module.vmp_extract_row_dft(&mut b_dft, row_i, &vmpmat_0);
-            assert_eq!(a_dft.raw(), b_dft.raw());
+                module.vmp_prepare_row(&mut vmpmat_0, row_i, col_in, &a, &mut tmp_bytes);
 
-            // Checks that a_big = extract(prepare_dft(mat_znx_dft, a_dft), b_big)
-            module.vmp_extract_row(&mut b_big, &vmpmat_0, row_i);
-            module.vec_znx_idft(&mut a_big, 0, &a_dft, 0, &mut tmp_bytes);
-            assert_eq!(a_big.raw(), b_big.raw());
+                // Checks that prepare(mat_znx_dft, a) = prepare_dft(mat_znx_dft, a_dft)
+                module.vmp_prepare_row_dft(&mut vmpmat_1, row_i, col_in, &a_dft);
+                assert_eq!(vmpmat_0.raw(), vmpmat_1.raw());
+
+                // Checks that a_dft = extract_dft(prepare(mat_znx_dft, a), b_dft)
+                module.vmp_extract_row_dft(&mut b_dft, &vmpmat_0, row_i, col_in);
+                assert_eq!(a_dft.raw(), b_dft.raw());
+
+                // Checks that a_big = extract(prepare_dft(mat_znx_dft, a_dft), b_big)
+                module.vmp_extract_row(log_base2k, &mut b, &vmpmat_0, row_i, col_in, &mut tmp_bytes);
+
+                (0..mat_cols_out).for_each(|col_out| {
+                    module.vec_znx_idft(&mut a_big, col_out, &a_dft, col_out, &mut tmp_bytes);
+                    module.vec_znx_big_normalize(log_base2k, &mut a, col_out, &a_big, col_out, &mut tmp_bytes);
+                });
+
+                assert_eq!(a.raw(), b.raw());
+            }
         }
 
         module.free();
diff --git a/base2k/src/scalar_znx_dft.rs b/base2k/src/scalar_znx_dft.rs
index ffb54b5..6fdb991 100644
--- a/base2k/src/scalar_znx_dft.rs
+++ b/base2k/src/scalar_znx_dft.rs
@@ -28,6 +28,7 @@ impl ZnxAlloc<FFT64> for ScalarZnxDft<FFT64> {
     type Scalar = u8;
 
     fn from_bytes_borrow(module: &Module<FFT64>, _rows: usize, cols: usize, _size: usize, bytes: &mut [u8]) -> Self {
+        debug_assert_eq!(bytes.len(), Self::bytes_of(module, _rows, cols, _size));
         Self {
             inner: ZnxBase::from_bytes_borrow(
                 module.n(),
@@ -61,6 +62,6 @@
 impl ZnxSliceSize for ScalarZnxDft<FFT64> {
     fn sl(&self) -> usize {
-        self.n()
+        self.n() * self.cols()
     }
 }
diff --git a/base2k/src/vec_znx.rs b/base2k/src/vec_znx.rs
index 125f32e..544c096 100644
--- a/base2k/src/vec_znx.rs
+++ b/base2k/src/vec_znx.rs
@@ -3,7 +3,7 @@
 use crate::Module;
 use crate::assert_alignement;
 use crate::cast_mut;
 use crate::ffi::znx;
-use crate::znx_base::{GetZnxBase, ZnxAlloc, ZnxBase, ZnxBasics, ZnxInfos, ZnxLayout, ZnxSliceSize, switch_degree};
+use crate::znx_base::{GetZnxBase, ZnxAlloc, ZnxBase, ZnxInfos, ZnxLayout, ZnxRsh, ZnxSliceSize, ZnxZero, switch_degree};
 use std::cmp::min;
 
 pub const VEC_ZNX_ROWS: usize = 1;
@@ -44,7 +44,9 @@ impl ZnxLayout for VecZnx {
     type Scalar = i64;
 }
 
-impl ZnxBasics for VecZnx {}
+impl ZnxZero for VecZnx {}
+
+impl ZnxRsh for VecZnx {}
 
 impl ZnxAlloc<FFT64> for VecZnx {
     type Scalar = i64;
@@ -84,7 +86,7 @@ impl VecZnx {
     ///
     /// * `log_base2k`: the base two logarithm of the coefficients decomposition.
     /// * `k`: the number of bits of precision to drop.
-    pub fn trunc_pow2(&mut self, log_base2k: usize, k: usize) {
+    pub fn trunc_pow2(&mut self, log_base2k: usize, k: usize, col: usize) {
         if k == 0 {
             return;
         }
@@ -101,7 +103,7 @@ impl VecZnx {
 
         if k_rem != 0 {
             let mask: i64 = ((1 << (log_base2k - k_rem - 1)) - 1) << k_rem;
-            self.at_limb_mut(self.size() - 1)
+            self.at_mut(col, self.size() - 1)
                 .iter_mut()
                 .for_each(|x: &mut i64| *x &= mask)
         }
@@ -111,8 +113,8 @@ impl VecZnx {
         copy_vec_znx_from(self, a);
     }
 
-    pub fn normalize(&mut self, log_base2k: usize, carry: &mut [u8]) {
-        normalize(log_base2k, self, carry)
+    pub fn normalize(&mut self, log_base2k: usize, col: usize, carry: &mut [u8]) {
+        normalize(log_base2k, self, col, carry)
     }
 
     pub fn switch_degree(&self, col: usize, a: &mut Self, col_a: usize) {
@@ -120,26 +122,25 @@ impl VecZnx {
     }
 
     // Prints the first `n` coefficients of each limb
-    pub fn print(&self, n: usize) {
-        (0..self.size()).for_each(|i| println!("{}: {:?}", i, &self.at_limb(i)[..n]))
+    pub fn print(&self, n: usize, col: usize) {
+        (0..self.size()).for_each(|j| println!("{}: {:?}", j, &self.at(col, j)[..n]));
     }
 }
 
-fn normalize_tmp_bytes(n: usize, size: usize) -> usize {
-    n * size * std::mem::size_of::<i64>()
+fn normalize_tmp_bytes(n: usize) -> usize {
+    n * std::mem::size_of::<i64>()
 }
 
-fn normalize(log_base2k: usize, a: &mut VecZnx, tmp_bytes: &mut [u8]) {
+fn normalize(log_base2k: usize, a: &mut VecZnx, a_col: usize, tmp_bytes: &mut [u8]) {
     let n: usize = a.n();
-    let cols: usize = a.cols();
 
     debug_assert!(
-        tmp_bytes.len() >= normalize_tmp_bytes(n, cols),
-        "invalid tmp_bytes: tmp_bytes.len()={} < normalize_tmp_bytes({}, {})",
+        tmp_bytes.len() >= normalize_tmp_bytes(n),
+        "invalid tmp_bytes: tmp_bytes.len()={} < normalize_tmp_bytes({})",
        tmp_bytes.len(),
        n,
-        cols,
    );
+
    #[cfg(debug_assertions)]
    {
        assert_alignement(tmp_bytes.as_ptr())
@@ -151,11 +152,11 @@
     znx::znx_zero_i64_ref(n as u64, carry_i64.as_mut_ptr());
     (0..a.size()).rev().for_each(|i| {
         znx::znx_normalize(
-            (n * cols) as u64,
+            n as u64,
             log_base2k as u64,
-            a.at_mut_ptr(0, i),
+            a.at_mut_ptr(a_col, i),
             carry_i64.as_mut_ptr(),
-            a.at_mut_ptr(0, i),
+            a.at_mut_ptr(a_col, i),
             carry_i64.as_mut_ptr(),
         )
     });
diff --git a/base2k/src/vec_znx_big.rs b/base2k/src/vec_znx_big.rs
index cbcd4b9..5ba7dde 100644
--- a/base2k/src/vec_znx_big.rs
+++ b/base2k/src/vec_znx_big.rs
@@ -1,5 +1,5 @@
 use crate::ffi::vec_znx_big;
-use crate::znx_base::{GetZnxBase, ZnxAlloc, ZnxBase, ZnxBasics, ZnxInfos, ZnxLayout, ZnxSliceSize};
+use crate::znx_base::{GetZnxBase, ZnxAlloc, ZnxBase, ZnxInfos, ZnxLayout, ZnxSliceSize, ZnxZero};
 use crate::{Backend, FFT64, Module, NTT120};
 use std::marker::PhantomData;
 
@@ -26,6 +26,7 @@ impl<B: Backend> ZnxAlloc<B> for VecZnxBig<B> {
     type Scalar = u8;
 
     fn from_bytes_borrow(module: &Module<B>, _rows: usize, cols: usize, size: usize, bytes: &mut [u8]) -> Self {
+        debug_assert_eq!(bytes.len(), Self::bytes_of(module, _rows, cols, size));
         VecZnxBig {
             inner: ZnxBase::from_bytes_borrow(module.n(), VEC_ZNX_BIG_ROWS, cols, size, bytes),
             _marker: PhantomData,
@@ -50,24 +51,24 @@ impl ZnxLayout for VecZnxBig<NTT120> {
     type Scalar = i128;
 }
 
-impl ZnxBasics for VecZnxBig<NTT120> {}
+impl ZnxZero for VecZnxBig<NTT120> {}
 
 impl ZnxSliceSize for VecZnxBig<FFT64> {
     fn sl(&self) -> usize {
-        self.n()
+        self.n() * self.cols()
     }
 }
 
 impl ZnxSliceSize for VecZnxBig<NTT120> {
     fn sl(&self) -> usize {
-        self.n() * 4
+        self.n() * 4 * self.cols()
     }
 }
 
-impl ZnxBasics for VecZnxBig<FFT64> {}
+impl ZnxZero for VecZnxBig<FFT64> {}
 
 impl VecZnxBig<FFT64> {
-    pub fn print(&self, n: usize) {
-        (0..self.size()).for_each(|i| println!("{}: {:?}", i, &self.at_limb(i)[..n]));
+    pub fn print(&self, n: usize, col: usize) {
+        (0..self.size()).for_each(|i| println!("{}: {:?}", i, &self.at(col, i)[..n]));
     }
 }
diff --git a/base2k/src/vec_znx_big_ops.rs b/base2k/src/vec_znx_big_ops.rs
index 9c6feee..8be526e 100644
--- a/base2k/src/vec_znx_big_ops.rs
+++ b/base2k/src/vec_znx_big_ops.rs
@@ -1,4 +1,4 @@
-use crate::ffi::vec_znx_big::{self, vec_znx_big_t};
+use crate::ffi::vec_znx;
 use crate::znx_base::{ZnxAlloc, ZnxInfos, ZnxLayout, ZnxSliceSize};
 use crate::{Backend, FFT64, Module, VecZnx, VecZnxBig, VecZnxOps, assert_alignement};
 
@@ -171,14 +171,17 @@
             assert_ne!(a.as_ptr(), b.as_ptr());
         }
         unsafe {
-            vec_znx_big::vec_znx_big_add(
+            vec_znx::vec_znx_add(
                 self.ptr,
-                res.at_mut_ptr(res_col * res.size(), 0) as *mut vec_znx_big_t,
+                res.at_mut_ptr(res_col, 0),
                 res.size() as u64,
-                a.at_ptr(a_col * res.size(), 0) as *const vec_znx_big_t,
+                res.sl() as u64,
+                a.at_ptr(a_col, 0),
                 a.size() as u64,
-                b.at_ptr(b_col * res.size(), 0) as *const vec_znx_big_t,
+                a.sl() as u64,
+                b.at_ptr(b_col, 0),
                 b.size() as u64,
+                b.sl() as u64,
             )
         }
     }
@@ -207,14 +210,17 @@
             assert_ne!(a.as_ptr(), b.as_ptr());
         }
         unsafe {
-            vec_znx_big::vec_znx_big_sub(
+            vec_znx::vec_znx_sub(
                 self.ptr,
-                res.at_mut_ptr(res_col * res.size(), 0) as *mut vec_znx_big_t,
+                res.at_mut_ptr(res_col, 0),
                 res.size() as u64,
-                a.at_ptr(a_col * res.size(), 0) as *const vec_znx_big_t,
+                res.sl() as u64,
+                a.at_ptr(a_col, 0),
                 a.size() as u64,
-                b.at_ptr(b_col * res.size(), 0) as *const vec_znx_big_t,
+                a.sl() as u64,
+                b.at_ptr(b_col, 0),
                 b.size() as u64,
+                b.sl() as u64,
             )
         }
     }
@@ -250,12 +256,14 @@
             assert_ne!(a.as_ptr(), b.as_ptr());
         }
         unsafe {
-            vec_znx_big::vec_znx_big_sub_small_b(
+            vec_znx::vec_znx_sub(
                 self.ptr,
-                res.at_mut_ptr(res_col * res.size(), 0) as *mut vec_znx_big_t,
+                res.at_mut_ptr(res_col, 0),
                 res.size() as u64,
-                a.at_ptr(a_col * a.size(), 0) as *const vec_znx_big_t,
+                res.sl() as u64,
+                a.at_ptr(a_col, 0),
                 a.size() as u64,
+                a.sl() as u64,
                 b.at_ptr(b_col, 0),
                 b.size() as u64,
                 b.sl() as u64,
@@ -287,15 +295,17 @@
             assert_ne!(a.as_ptr(), b.as_ptr());
         }
         unsafe {
-            vec_znx_big::vec_znx_big_sub_small_a(
+            vec_znx::vec_znx_sub(
                 self.ptr,
-                res.at_mut_ptr(res_col * res.size(), 0) as *mut vec_znx_big_t,
+                res.at_mut_ptr(res_col, 0),
                 res.size() as u64,
+                res.sl() as u64,
                 a.at_ptr(a_col, 0),
                 a.size() as u64,
                 a.sl() as u64,
-                b.at_ptr(b_col * b.size(), 0) as *const vec_znx_big_t,
+                b.at_ptr(b_col, 0),
                 b.size() as u64,
+                b.sl() as u64,
             )
         }
     }
@@ -324,12 +334,14 @@
             assert_ne!(a.as_ptr(), b.as_ptr());
         }
         unsafe {
-            vec_znx_big::vec_znx_big_add_small(
+            vec_znx::vec_znx_add(
                 self.ptr,
-                res.at_mut_ptr(res_col * res.size(), 0) as *mut vec_znx_big_t,
+                res.at_mut_ptr(res_col, 0),
                 res.size() as u64,
-                a.at_ptr(a_col * a.size(), 0) as *const vec_znx_big_t,
+                res.sl() as u64,
+                a.at_ptr(a_col, 0),
                 a.size() as u64,
+                a.sl() as u64,
                 b.at_ptr(b_col, 0),
                 b.size() as u64,
                 b.sl() as u64,
@@ -365,14 +377,15 @@
             assert_alignement(tmp_bytes.as_ptr());
         }
         unsafe {
-            vec_znx_big::vec_znx_big_normalize_base2k(
+            vec_znx::vec_znx_normalize_base2k(
                 self.ptr,
                 log_base2k as u64,
                 res.at_mut_ptr(res_col, 0),
                 res.size() as u64,
                 res.sl() as u64,
-                a.at_ptr(a_col * a.size(), 0) as *const vec_znx_big_t,
+                a.at_ptr(a_col, 0),
                 a.size() as u64,
+                a.sl() as u64,
                 tmp_bytes.as_mut_ptr(),
             );
         }
@@ -385,13 +398,15 @@
             assert_eq!(res.n(), self.n());
         }
         unsafe {
-            vec_znx_big::vec_znx_big_automorphism(
+            vec_znx::vec_znx_automorphism(
                 self.ptr,
                 k,
-                res.at_mut_ptr(res_col * res.size(), 0) as *mut vec_znx_big_t,
+                res.at_mut_ptr(res_col, 0),
                 res.size() as u64,
-                a.at_ptr(a_col * a.size(), 0) as *const vec_znx_big_t,
+                res.sl() as u64,
+                a.at_ptr(a_col, 0),
                 a.size() as u64,
+                a.sl() as u64,
             )
         }
     }
diff --git a/base2k/src/vec_znx_dft.rs b/base2k/src/vec_znx_dft.rs
index a9dd378..b187645 100644
--- a/base2k/src/vec_znx_dft.rs
+++ b/base2k/src/vec_znx_dft.rs
@@ -1,5 +1,5 @@
 use crate::ffi::vec_znx_dft;
-use crate::znx_base::{GetZnxBase, ZnxAlloc, ZnxBase, ZnxInfos, ZnxLayout, ZnxSliceSize};
+use crate::znx_base::{GetZnxBase, ZnxAlloc, ZnxBase, ZnxInfos, ZnxLayout, ZnxSliceSize, ZnxZero};
 use crate::{Backend, FFT64, Module, VecZnxBig};
 use std::marker::PhantomData;
 
@@ -26,6 +26,7 @@ impl ZnxAlloc<FFT64> for VecZnxDft<FFT64> {
     type Scalar = u8;
 
     fn from_bytes_borrow(module: &Module<FFT64>, _rows: usize, cols: usize, size: usize, bytes: &mut [u8]) -> Self {
+        debug_assert_eq!(bytes.len(), Self::bytes_of(module, _rows, cols, size));
         Self {
             inner: ZnxBase::from_bytes_borrow(module.n(), VEC_ZNX_DFT_ROWS, cols, size, bytes),
             _marker: PhantomData,
@@ -46,6 +47,8 @@ impl ZnxLayout for VecZnxDft<FFT64> {
     type Scalar = f64;
 }
 
+impl ZnxZero for VecZnxDft<FFT64> {}
+
 impl ZnxSliceSize for VecZnxDft<FFT64> {
     fn sl(&self) -> usize {
         self.n()
@@ -53,8 +56,8 @@
     }
 }
 
 impl VecZnxDft<FFT64> {
-    pub fn print(&self, n: usize) {
-        (0..self.size()).for_each(|i| println!("{}: {:?}", i, &self.at_limb(i)[..n]));
+    pub fn print(&self, n: usize, col: usize) {
+        (0..self.size()).for_each(|i| println!("{}: {:?}", i, &self.at(col, i)[..n]));
     }
 }
 
@@ -63,6 +66,10 @@ impl VecZnxDft<FFT64> {
     /// The returned [VecZnxBig] shares the backing array
     /// with the original [VecZnxDft].
     pub fn alias_as_vec_znx_big(&mut self) -> VecZnxBig<FFT64> {
+        assert!(
+            self.data().len() == 0,
+            "cannot alias VecZnxDft into VecZnxBig if it owns the data"
+        );
         VecZnxBig::<FFT64> {
             inner: ZnxBase {
                 data: Vec::new(),
diff --git a/base2k/src/vec_znx_dft_ops.rs b/base2k/src/vec_znx_dft_ops.rs
index 57b3777..679abce 100644
--- a/base2k/src/vec_znx_dft_ops.rs
+++ b/base2k/src/vec_znx_dft_ops.rs
@@ -4,7 +4,8 @@
 use crate::znx_base::ZnxAlloc;
 use crate::znx_base::ZnxInfos;
 use crate::znx_base::ZnxLayout;
 use crate::znx_base::ZnxSliceSize;
-use crate::{Backend, FFT64, Module, VecZnx, VecZnxBig, VecZnxDft, assert_alignement};
+use crate::{Backend, FFT64, Module, VecZnx, VecZnxBig, VecZnxDft, ZnxZero, assert_alignement};
+use std::cmp::min;
 
 pub trait VecZnxDftOps<B: Backend> {
     /// Allocates a vector Z[X]/(X^N+1) that stores normalized in the DFT space.
@@ -77,19 +78,21 @@
     }
 
     fn vec_znx_idft_tmp_a(&self, res: &mut VecZnxBig<FFT64>, res_col: usize, a: &mut VecZnxDft<FFT64>, a_col: usize) {
-        #[cfg(debug_assertions)]
-        {
-            assert_eq!(res.poly_count(), a.poly_count());
-        }
+        let min_size: usize = min(res.size(), a.size());
         unsafe {
-            vec_znx_dft::vec_znx_idft_tmp_a(
-                self.ptr,
-                res.at_mut_ptr(res_col * res.size(), 0) as *mut vec_znx_big::vec_znx_big_t,
-                res.size() as u64,
-                a.at_ptr(a_col * a.size(), 0) as *mut vec_znx_dft::vec_znx_dft_t,
-                a.size() as u64,
-            )
+            (0..min_size).for_each(|j| {
+                vec_znx_dft::vec_znx_idft_tmp_a(
+                    self.ptr,
+                    res.at_mut_ptr(res_col, j) as *mut vec_znx_big::vec_znx_big_t,
+                    1 as u64,
+                    a.at_ptr(a_col, j) as *mut vec_znx_dft::vec_znx_dft_t,
+                    1 as u64,
+                )
+            });
+            (min_size..res.size()).for_each(|j| {
+                res.zero_at(res_col, j);
+            })
         }
     }
 
@@ -102,15 +105,22 @@
     /// # Panics
     /// If b.cols < a_cols
     fn vec_znx_dft(&self, res: &mut VecZnxDft<FFT64>, res_col: usize, a: &VecZnx, a_col: usize) {
+        let min_size: usize = min(res.size(), a.size());
+
         unsafe {
-            vec_znx_dft::vec_znx_dft(
-                self.ptr,
-                res.at_mut_ptr(res_col * res.size(), 0) as *mut vec_znx_dft::vec_znx_dft_t,
-                res.size() as u64,
-                a.at_ptr(a_col, 0),
-                a.size() as u64,
-                a.sl() as u64,
-            )
+            (0..min_size).for_each(|j| {
+                vec_znx_dft::vec_znx_dft(
+                    self.ptr,
+                    res.at_mut_ptr(res_col, j) as *mut vec_znx_dft::vec_znx_dft_t,
+                    1 as u64,
+                    a.at_ptr(a_col, j),
+                    1 as u64,
+                    a.sl() as u64,
+                )
+            });
+            (min_size..res.size()).for_each(|j| {
+                res.zero_at(res_col, j);
+            });
         }
     }
 
@@ -126,15 +136,23 @@
             );
             assert_alignement(tmp_bytes.as_ptr())
         }
+
+        let min_size: usize = min(res.size(), a.size());
+
         unsafe {
-            vec_znx_dft::vec_znx_idft(
-                self.ptr,
-                res.at_mut_ptr(res_col * res.size(), 0) as *mut vec_znx_big::vec_znx_big_t,
-                res.size() as u64,
-                a.at_ptr(a_col * res.size(), 0) as *const vec_znx_dft::vec_znx_dft_t,
-                a.size() as u64,
-                tmp_bytes.as_mut_ptr(),
-            )
+            (0..min_size).for_each(|j| {
+                vec_znx_dft::vec_znx_idft(
+                    self.ptr,
+                    res.at_mut_ptr(res_col, j) as *mut vec_znx_big::vec_znx_big_t,
+                    1 as u64,
+                    a.at_ptr(a_col, j) as *const vec_znx_dft::vec_znx_dft_t,
+                    1 as u64,
+                    tmp_bytes.as_mut_ptr(),
+                )
+            });
+            (min_size..res.size()).for_each(|j| {
+                res.zero_at(res_col, j);
+            });
         }
     }
 }
diff --git a/base2k/src/znx_base.rs b/base2k/src/znx_base.rs
index 64ad85f..4cacb70 100644
--- a/base2k/src/znx_base.rs
+++ b/base2k/src/znx_base.rs
@@ -22,6 +22,33 @@ pub struct ZnxBase {
     pub ptr: *mut u8,
 }
 
+impl ZnxBase {
+    pub fn from_bytes(n: usize, rows: usize, cols: usize, size: usize, mut bytes: Vec<u8>) -> Self {
+        let mut res: Self = Self::from_bytes_borrow(n, rows, cols, size, &mut bytes);
+        res.data = bytes;
+        res
+    }
+
+    pub fn from_bytes_borrow(n: usize, rows: usize, cols: usize, size: usize, bytes: &mut [u8]) -> Self {
+        #[cfg(debug_assertions)]
+        {
+            assert_eq!(n & (n - 1), 0, "n must be a power of two");
+            assert!(n > 0, "n must be greater than 0");
+            assert!(rows > 0, "rows must be greater than 0");
+            assert!(cols > 0, "cols must be greater than 0");
+            assert!(size > 0, "size must be greater than 0");
+        }
+        Self {
+            n: n,
+            rows: rows,
+            cols: cols,
+            size: size,
+            data: Vec::new(),
+            ptr: bytes.as_mut_ptr(),
+        }
+    }
+}
+
 pub trait GetZnxBase {
     fn znx(&self) -> &ZnxBase;
     fn znx_mut(&mut self) -> &mut ZnxBase;
@@ -52,10 +79,12 @@
         self.znx().size
     }
 
+    /// Returns the underlying raw bytes array.
     fn data(&self) -> &[u8] {
         &self.znx().data
     }
 
+    /// Returns a pointer to the underlying raw bytes array.
     fn ptr(&self) -> *mut u8 {
         self.znx().ptr
     }
@@ -72,33 +101,6 @@
 pub trait ZnxSliceSize {
     fn sl(&self) -> usize;
 }
 
-impl ZnxBase {
-    pub fn from_bytes(n: usize, rows: usize, cols: usize, size: usize, mut bytes: Vec<u8>) -> Self {
-        let mut res: Self = Self::from_bytes_borrow(n, rows, cols, size, &mut bytes);
-        res.data = bytes;
-        res
-    }
-
-    pub fn from_bytes_borrow(n: usize, rows: usize, cols: usize, size: usize, bytes: &mut [u8]) -> Self {
-        #[cfg(debug_assertions)]
-        {
-            assert_eq!(n & (n - 1), 0, "n must be a power of two");
-            assert!(n > 0, "n must be greater than 0");
-            assert!(rows > 0, "rows must be greater than 0");
-            assert!(cols > 0, "cols must be greater than 0");
-            assert!(size > 0, "size must be greater than 0");
-        }
-        Self {
-            n: n,
-            rows: rows,
-            cols: cols,
-            size: size,
-            data: Vec::new(),
-            ptr: bytes.as_mut_ptr(),
-        }
-    }
-}
-
 pub trait ZnxAlloc<B: Backend>
 where
     Self: Sized + ZnxInfos,
 {
@@ -148,25 +150,25 @@
         unsafe { std::slice::from_raw_parts_mut(self.as_mut_ptr(), self.n() * self.poly_count()) }
     }
 
-    /// Returns a non-mutable pointer starting at the (i, j)-th small polynomial.
+    /// Returns a non-mutable pointer starting at the j-th small polynomial of the i-th column.
     fn at_ptr(&self, i: usize, j: usize) -> *const Self::Scalar {
         #[cfg(debug_assertions)]
         {
             assert!(i < self.cols());
             assert!(j < self.size());
         }
-        let offset = self.n() * (j * self.cols() + i);
+        let offset: usize = self.n() * (j * self.cols() + i);
         unsafe { self.as_ptr().add(offset) }
     }
 
-    /// Returns a mutable pointer starting at the (i, j)-th small polynomial.
+    /// Returns a mutable pointer starting at the j-th small polynomial of the i-th column.
     fn at_mut_ptr(&mut self, i: usize, j: usize) -> *mut Self::Scalar {
         #[cfg(debug_assertions)]
         {
             assert!(i < self.cols());
             assert!(j < self.size());
         }
-        let offset = self.n() * (j * self.cols() + i);
+        let offset: usize = self.n() * (j * self.cols() + i);
         unsafe { self.as_mut_ptr().add(offset) }
     }
 
@@ -179,16 +181,6 @@
     fn at_mut(&mut self, i: usize, j: usize) -> &mut [Self::Scalar] {
         unsafe { std::slice::from_raw_parts_mut(self.at_mut_ptr(i, j), self.n()) }
     }
-
-    /// Returns non-mutable reference to the i-th limb.
-    fn at_limb(&self, j: usize) -> &[Self::Scalar] {
-        unsafe { std::slice::from_raw_parts(self.at_ptr(0, j), self.n() * self.cols()) }
-    }
-
-    /// Returns mutable reference to the i-th limb.
-    fn at_limb_mut(&mut self, j: usize) -> &mut [Self::Scalar] {
-        unsafe { std::slice::from_raw_parts_mut(self.at_mut_ptr(0, j), self.n() * self.cols()) }
-    }
 }
 
 use std::convert::TryFrom;
@@ -221,14 +213,17 @@
 impl IntegerType for i128 {
     const BITS: u32 = 128;
 }
 
-pub trait ZnxBasics: ZnxLayout
+pub trait ZnxZero: ZnxLayout
 where
     Self: Sized,
-    Self::Scalar: IntegerType,
 {
     fn zero(&mut self) {
         unsafe {
-            std::ptr::write_bytes(self.as_mut_ptr(), 0, self.n() * size_of::<Self::Scalar>());
+            std::ptr::write_bytes(
+                self.as_mut_ptr(),
+                0,
+                self.n() * size_of::<Self::Scalar>() * self.poly_count(),
+            );
         }
     }
@@ -241,13 +236,19 @@ where
             );
         }
     }
+}
 
-    fn rsh(&mut self, log_base2k: usize, k: usize, carry: &mut [u8]) {
-        rsh(log_base2k, self, k, carry)
+pub trait ZnxRsh: ZnxLayout + ZnxZero
+where
+    Self: Sized,
+    Self::Scalar: IntegerType,
+{
+    fn rsh(&mut self, k: usize, log_base2k: usize, col: usize, carry: &mut [u8]) {
+        rsh(k, log_base2k, self, col, carry)
     }
 }
 
-pub fn rsh<V: ZnxBasics>(log_base2k: usize, a: &mut V, k: usize, tmp_bytes: &mut [u8])
+pub fn rsh<V: ZnxRsh>(k: usize, log_base2k: usize, a: &mut V, a_col: usize, tmp_bytes: &mut [u8])
 where
     V::Scalar: IntegerType,
 {
@@ -258,7 +259,7 @@
     #[cfg(debug_assertions)]
     {
         assert!(
-            tmp_bytes.len() >= rsh_tmp_bytes::<V::Scalar>(n, cols),
+            tmp_bytes.len() >= rsh_tmp_bytes::<V::Scalar>(n),
             "invalid carry: carry.len()/size_of::<Self::Scalar>()={} < rsh_tmp_bytes({}, {})",
             tmp_bytes.len() / size_of::<V::Scalar>(),
             n,
@@ -291,7 +292,7 @@
     let k_rem_t: V::Scalar = V::Scalar::try_from(k_rem).unwrap();
 
     (steps..size).for_each(|i| {
-        izip!(carry.iter_mut(), a.at_limb_mut(i).iter_mut()).for_each(|(ci, xi)| {
+        izip!(carry.iter_mut(), a.at_mut(a_col, i).iter_mut()).for_each(|(ci, xi)| {
             *xi += *ci << log_base2k_t;
             *ci = get_base_k_carry(*xi, shift);
             *xi = (*xi - *ci) >> k_rem_t;
@@ -305,11 +306,11 @@
 fn get_base_k_carry<T: IntegerType>(x: T, shift: T) -> T {
     (x << shift) >> shift
 }
 
-pub fn rsh_tmp_bytes<T>(n: usize, cols: usize) -> usize {
-    n * cols * std::mem::size_of::<T>()
+pub fn rsh_tmp_bytes<T>(n: usize) -> usize {
+    n * std::mem::size_of::<T>()
 }
 
-pub fn switch_degree<T: ZnxBasics>(b: &mut T, col_b: usize, a: &T, col_a: usize)
+pub fn switch_degree<T: ZnxLayout + ZnxZero>(b: &mut T, col_b: usize, a: &T, col_a: usize)
 where
     <T as ZnxLayout>::Scalar: IntegerType,
 {