From 6f7b93c7ca7d1234ed3cc04007c608e28c6dd5d7 Mon Sep 17 00:00:00 2001 From: Jean-Philippe Bossuat Date: Wed, 30 Apr 2025 13:43:18 +0200 Subject: [PATCH] wip major refactoring (compiles & all test + example passing) --- base2k/examples/rlwe_encrypt.rs | 18 +- base2k/examples/vector_matrix_product.rs | 4 +- base2k/src/encoding.rs | 12 +- base2k/src/lib.rs | 18 +- base2k/src/mat_znx_dft.rs | 257 +++++++--------- base2k/src/sampling.rs | 8 +- base2k/src/scalar_znx_dft.rs | 56 ++-- base2k/src/stats.rs | 3 +- base2k/src/vec_znx.rs | 156 +++------- base2k/src/vec_znx_big.rs | 132 +++----- base2k/src/vec_znx_big_ops.rs | 206 ++++++------- base2k/src/vec_znx_dft.rs | 373 +++-------------------- base2k/src/vec_znx_dft_ops.rs | 140 +++++++++ base2k/src/vec_znx_ops.rs | 15 +- base2k/src/{commons.rs => znx_base.rs} | 122 +++++++- rlwe/Cargo.toml | 2 - rlwe/src/automorphism.rs | 8 +- rlwe/src/keys.rs | 2 +- 18 files changed, 662 insertions(+), 870 deletions(-) create mode 100644 base2k/src/vec_znx_dft_ops.rs rename base2k/src/{commons.rs => znx_base.rs} (67%) diff --git a/base2k/examples/rlwe_encrypt.rs b/base2k/examples/rlwe_encrypt.rs index ee2bd02..0f75ef3 100644 --- a/base2k/examples/rlwe_encrypt.rs +++ b/base2k/examples/rlwe_encrypt.rs @@ -23,10 +23,10 @@ fn main() { s.fill_ternary_prob(0.5, &mut source); // Buffer to store s in the DFT domain - let mut s_ppol: ScalarZnxDft = module.new_svp_ppol(); + let mut s_dft: ScalarZnxDft = module.new_scalar_znx_dft(); - // s_ppol <- DFT(s) - module.svp_prepare(&mut s_ppol, &s); + // s_dft <- DFT(s) + module.svp_prepare(&mut s_dft, &s); // Allocates a VecZnx with two columns: ct=(0, 0) let mut ct: VecZnx = module.new_vec_znx( @@ -46,16 +46,17 @@ fn main() { // Applies DFT(ct[1]) * DFT(s) module.svp_apply_dft( &mut buf_dft, // DFT(ct[1] * s) - &s_ppol, // DFT(s) + 0, // Selects the first column of res + &s_dft, // DFT(s) &ct, 1, // Selects the second column of ct ); // Alias scratch space (VecZnxDft is always at least as big as VecZnxBig) - let mut buf_big: VecZnxBig = buf_dft.as_vec_znx_big(); + let mut buf_big: VecZnxBig = buf_dft.alias_as_vec_znx_big(); // BIG(ct[1] * s) <- IDFT(DFT(ct[1] * s)) (not normalized) - module.vec_znx_idft_tmp_a(&mut buf_big, &mut buf_dft); + module.vec_znx_idft_tmp_a(&mut buf_big, 0, &mut buf_dft, 0); // Creates a plaintext: VecZnx with 1 column let mut m: VecZnx = module.new_vec_znx( @@ -103,13 +104,14 @@ fn main() { // DFT(ct[1] * s) module.svp_apply_dft( &mut buf_dft, - &s_ppol, + 0, // Selects the first column of res. + &s_dft, &ct, 1, // Selects the second column of ct (ct[1]) ); // BIG(c1 * s) = IDFT(DFT(c1 * s)) - module.vec_znx_idft_tmp_a(&mut buf_big, &mut buf_dft); + module.vec_znx_idft_tmp_a(&mut buf_big, 0, &mut buf_dft, 0); // BIG(c1 * s) + ct[0] module.vec_znx_big_add_small_inplace(&mut buf_big, 0, &ct, 0); diff --git a/base2k/examples/vector_matrix_product.rs b/base2k/examples/vector_matrix_product.rs index 2f4b1fb..e565be1 100644 --- a/base2k/examples/vector_matrix_product.rs +++ b/base2k/examples/vector_matrix_product.rs @@ -42,8 +42,8 @@ fn main() { let mut c_dft: VecZnxDft = module.new_vec_znx_dft(1, limbs_mat); module.vmp_apply_dft(&mut c_dft, &a, &mat_znx_dft, &mut buf); - let mut c_big: VecZnxBig = c_dft.as_vec_znx_big(); - module.vec_znx_idft_tmp_a(&mut c_big, &mut c_dft); + let mut c_big: VecZnxBig = c_dft.alias_as_vec_znx_big(); + module.vec_znx_idft_tmp_a(&mut c_big, 0, &mut c_dft, 0); let mut res: VecZnx = module.new_vec_znx(1, limbs_vec); module.vec_znx_big_normalize(log_base2k, &mut res, 0, &c_big, 0, &mut buf); diff --git a/base2k/src/encoding.rs b/base2k/src/encoding.rs index 7f8a0cc..b7d014d 100644 --- a/base2k/src/encoding.rs +++ b/base2k/src/encoding.rs @@ -1,5 +1,6 @@ use crate::ffi::znx::znx_zero_i64_ref; -use crate::{VecZnx, ZnxInfos, ZnxLayout}; +use crate::znx_base::ZnxLayout; +use crate::{VecZnx, znx_base::ZnxInfos}; use itertools::izip; use rug::{Assign, Float}; use std::cmp::min; @@ -262,7 +263,10 @@ fn decode_coeff_i64(a: &VecZnx, col_i: usize, log_base2k: usize, log_k: usize, i #[cfg(test)] mod tests { - use crate::{Encoding, FFT64, Module, VecZnx, ZnxBase, ZnxInfos, ZnxLayout}; + use crate::{ + Encoding, FFT64, Module, VecZnx, VecZnxOps, + znx_base::{ZnxInfos, ZnxLayout}, + }; use itertools::izip; use sampling::source::Source; @@ -273,7 +277,7 @@ mod tests { let log_base2k: usize = 17; let size: usize = 5; let log_k: usize = size * log_base2k - 5; - let mut a: VecZnx = VecZnx::new(&module, 2, size); + let mut a: VecZnx = module.new_vec_znx(2, size); let mut source: Source = Source::new([0u8; 32]); let raw: &mut [i64] = a.raw_mut(); raw.iter_mut().enumerate().for_each(|(i, x)| *x = i as i64); @@ -295,7 +299,7 @@ mod tests { let log_base2k: usize = 17; let size: usize = 5; let log_k: usize = size * log_base2k - 5; - let mut a: VecZnx = VecZnx::new(&module, 2, size); + let mut a: VecZnx = module.new_vec_znx(2, size); let mut source = Source::new([0u8; 32]); let raw: &mut [i64] = a.raw_mut(); raw.iter_mut().enumerate().for_each(|(i, x)| *x = i as i64); diff --git a/base2k/src/lib.rs b/base2k/src/lib.rs index 7a8a3f8..f57e482 100644 --- a/base2k/src/lib.rs +++ b/base2k/src/lib.rs @@ -1,4 +1,3 @@ -pub mod commons; pub mod encoding; #[allow(non_camel_case_types, non_snake_case, non_upper_case_globals, dead_code, improper_ctypes)] // Other modules and exports @@ -12,9 +11,10 @@ pub mod vec_znx; pub mod vec_znx_big; pub mod vec_znx_big_ops; pub mod vec_znx_dft; +pub mod vec_znx_dft_ops; pub mod vec_znx_ops; +pub mod znx_base; -pub use commons::*; pub use encoding::*; pub use mat_znx_dft::*; pub use module::*; @@ -26,7 +26,9 @@ pub use vec_znx::*; pub use vec_znx_big::*; pub use vec_znx_big_ops::*; pub use vec_znx_dft::*; +pub use vec_znx_dft_ops::*; pub use vec_znx_ops::*; +pub use znx_base::*; pub const GALOISGENERATOR: u64 = 5; pub const DEFAULTALIGN: usize = 64; @@ -110,14 +112,8 @@ pub fn alloc_aligned_custom(size: usize, align: usize) -> Vec { unsafe { Vec::from_raw_parts(ptr, len, cap) } } -// Allocates an aligned of size equal to the smallest power of two equal or greater to `size` that is -// at least as bit as DEFAULTALIGN / std::mem::size_of::(). +/// Allocates an aligned of size equal to the smallest multiple +/// of [DEFAULTALIGN] that is equal or greater to `size`. pub fn alloc_aligned(size: usize) -> Vec { - alloc_aligned_custom::( - std::cmp::max( - size.next_power_of_two(), - DEFAULTALIGN / std::mem::size_of::(), - ), - DEFAULTALIGN, - ) + alloc_aligned_custom::(size + (size % DEFAULTALIGN), DEFAULTALIGN) } diff --git a/base2k/src/mat_znx_dft.rs b/base2k/src/mat_znx_dft.rs index b40ed71..9b5e2ca 100644 --- a/base2k/src/mat_znx_dft.rs +++ b/base2k/src/mat_znx_dft.rs @@ -1,103 +1,75 @@ use crate::ffi::vec_znx_big::vec_znx_big_t; use crate::ffi::vec_znx_dft::vec_znx_dft_t; use crate::ffi::vmp::{self, vmp_pmat_t}; -use crate::{Backend, FFT64, Module, VecZnx, VecZnxBig, VecZnxDft, ZnxInfos, ZnxLayout, alloc_aligned, assert_alignement}; +use crate::znx_base::{GetZnxBase, ZnxAlloc, ZnxBase, ZnxInfos, ZnxLayout, ZnxSliceSize}; +use crate::{Backend, FFT64, Module, VecZnx, VecZnxBig, VecZnxDft, alloc_aligned, assert_alignement}; use std::marker::PhantomData; /// Vector Matrix Product Prepared Matrix: a vector of [VecZnx], /// stored as a 3D matrix in the DFT domain in a single contiguous array. -/// Each col of the [VmpPMat] can be seen as a collection of [VecZnxDft]. +/// Each col of the [MatZnxDft] can be seen as a collection of [VecZnxDft]. /// -/// [VmpPMat] is used to permform a vector matrix product between a [VecZnx]/[VecZnxDft] and a [VmpPMat]. -/// See the trait [VmpPMatOps] for additional information. +/// [MatZnxDft] is used to permform a vector matrix product between a [VecZnx]/[VecZnxDft] and a [MatZnxDft]. +/// See the trait [MatZnxDftOps] for additional information. pub struct MatZnxDft { - /// Raw data, is empty if borrowing scratch space. - data: Vec, - /// Pointer to data. Can point to scratch space. - ptr: *mut u8, - /// The ring degree of each polynomial. - n: usize, - /// Number of rows - rows: usize, - /// Number of cols - cols: usize, - /// The number of small polynomials - size: usize, + pub inner: ZnxBase, _marker: PhantomData, } -impl ZnxInfos for MatZnxDft { - fn n(&self) -> usize { - self.n +impl GetZnxBase for MatZnxDft { + fn znx(&self) -> &ZnxBase { + &self.inner } - fn rows(&self) -> usize { - self.rows - } - - fn cols(&self) -> usize { - self.cols - } - - fn size(&self) -> usize { - self.size + fn znx_mut(&mut self) -> &mut ZnxBase { + &mut self.inner } } -impl MatZnxDft { - fn new(module: &Module, rows: usize, cols: usize, size: usize) -> MatZnxDft { - let mut data: Vec = alloc_aligned::(module.bytes_of_mat_znx_dft(rows, cols, size)); - let ptr: *mut u8 = data.as_mut_ptr(); - MatZnxDft:: { - data: data, - ptr: ptr, - n: module.n(), - rows: rows, - cols: cols, - size: size, +impl ZnxInfos for MatZnxDft {} + +impl ZnxSliceSize for MatZnxDft { + fn sl(&self) -> usize { + self.n() + } +} + +impl ZnxLayout for MatZnxDft { + type Scalar = f64; +} + +impl ZnxAlloc for MatZnxDft { + type Scalar = u8; + + fn from_bytes_borrow(module: &Module, rows: usize, cols: usize, size: usize, bytes: &mut [u8]) -> Self { + Self { + inner: ZnxBase::from_bytes_borrow(module.n(), rows, cols, size, bytes), _marker: PhantomData, } } - pub fn as_ptr(&self) -> *const u8 { - self.ptr + fn bytes_of(module: &Module, rows: usize, cols: usize, size: usize) -> usize { + unsafe { vmp::bytes_of_vmp_pmat(module.ptr, rows as u64, size as u64) as usize * cols } } +} - pub fn as_mut_ptr(&self) -> *mut u8 { - self.ptr - } - - pub fn borrowed(&self) -> bool { - self.data.len() == 0 - } - - /// Returns a non-mutable reference to the entire contiguous array of the [VmpPMat]. - pub fn raw(&self) -> &[f64] { - let ptr: *const f64 = self.ptr as *const f64; - let size: usize = self.n() * self.poly_count(); - unsafe { &std::slice::from_raw_parts(ptr, size) } - } - - /// Returns a mutable reference of to the entire contiguous array of the [VmpPMat]. - pub fn raw_mut(&self) -> &mut [f64] { - let ptr: *mut f64 = self.ptr as *mut f64; - let size: usize = self.n() * self.poly_count(); - unsafe { std::slice::from_raw_parts_mut(ptr, size) } - } - - /// Returns a copy of the backend array at index (i, j) of the [VmpPMat]. +impl MatZnxDft { + /// Returns a copy of the backend array at index (i, j) of the [MatZnxDft]. /// /// # Arguments /// /// * `row`: row index (i). /// * `col`: col index (j). - pub fn at(&self, row: usize, col: usize) -> Vec { - let mut res: Vec = alloc_aligned(self.n); + #[allow(dead_code)] + fn at(&self, row: usize, col: usize) -> Vec { + let n: usize = self.n(); - if self.n < 8 { - res.copy_from_slice(&self.raw()[(row + col * self.rows()) * self.n()..(row + col * self.rows()) * (self.n() + 1)]); + let mut res: Vec = alloc_aligned(n); + + if n < 8 { + res.copy_from_slice(&self.raw()[(row + col * self.rows()) * n..(row + col * self.rows()) * (n + 1)]); } else { - (0..self.n >> 3).for_each(|blk| { + (0..n >> 3).for_each(|blk| { res[blk * 8..(blk + 1) * 8].copy_from_slice(&self.at_block(row, col, blk)[..8]); }); } @@ -105,6 +77,7 @@ impl MatZnxDft { res } + #[allow(dead_code)] fn at_block(&self, row: usize, col: usize, blk: usize) -> &[f64] { let nrows: usize = self.rows(); let nsize: usize = self.size(); @@ -117,11 +90,11 @@ impl MatZnxDft { } /// This trait implements methods for vector matrix product, -/// that is, multiplying a [VecZnx] with a [VmpPMat]. +/// that is, multiplying a [VecZnx] with a [MatZnxDft]. pub trait MatZnxDftOps { fn bytes_of_mat_znx_dft(&self, rows: usize, cols: usize, size: usize) -> usize; - /// Allocates a new [VmpPMat] with the given number of rows and columns. + /// Allocates a new [MatZnxDft] with the given number of rows and columns. /// /// # Arguments /// @@ -129,83 +102,83 @@ pub trait MatZnxDftOps { /// * `size`: number of size (number of size of each [VecZnxDft]). fn new_mat_znx_dft(&self, rows: usize, cols: usize, size: usize) -> MatZnxDft; - /// Returns the number of bytes needed as scratch space for [VmpPMatOps::vmp_prepare_contiguous]. + /// Returns the number of bytes needed as scratch space for [MatZnxDftOps::vmp_prepare_contiguous]. /// /// # Arguments /// - /// * `rows`: number of rows of the [VmpPMat] used in [VmpPMatOps::vmp_prepare_contiguous]. - /// * `size`: number of size of the [VmpPMat] used in [VmpPMatOps::vmp_prepare_contiguous]. + /// * `rows`: number of rows of the [MatZnxDft] used in [MatZnxDftOps::vmp_prepare_contiguous]. + /// * `size`: number of size of the [MatZnxDft] used in [MatZnxDftOps::vmp_prepare_contiguous]. fn vmp_prepare_tmp_bytes(&self, rows: usize, cols: usize, size: usize) -> usize; - /// Prepares a [VmpPMat] from a contiguous array of [i64]. + /// Prepares a [MatZnxDft] from a contiguous array of [i64]. /// The helper struct [Matrix3D] can be used to contruct and populate /// the appropriate contiguous array. /// /// # Arguments /// - /// * `b`: [VmpPMat] on which the values are encoded. - /// * `a`: the contiguous array of [i64] of the 3D matrix to encode on the [VmpPMat]. - /// * `buf`: scratch space, the size of buf can be obtained with [VmpPMatOps::vmp_prepare_tmp_bytes]. + /// * `b`: [MatZnxDft] on which the values are encoded. + /// * `a`: the contiguous array of [i64] of the 3D matrix to encode on the [MatZnxDft]. + /// * `buf`: scratch space, the size of buf can be obtained with [MatZnxDftOps::vmp_prepare_tmp_bytes]. fn vmp_prepare_contiguous(&self, b: &mut MatZnxDft, a: &[i64], buf: &mut [u8]); - /// Prepares the ith-row of [VmpPMat] from a [VecZnx]. + /// Prepares the ith-row of [MatZnxDft] from a [VecZnx]. /// /// # Arguments /// - /// * `b`: [VmpPMat] on which the values are encoded. - /// * `a`: the vector of [VecZnx] to encode on the [VmpPMat]. + /// * `b`: [MatZnxDft] on which the values are encoded. + /// * `a`: the vector of [VecZnx] to encode on the [MatZnxDft]. /// * `row_i`: the index of the row to prepare. - /// * `buf`: scratch space, the size of buf can be obtained with [VmpPMatOps::vmp_prepare_tmp_bytes]. + /// * `buf`: scratch space, the size of buf can be obtained with [MatZnxDftOps::vmp_prepare_tmp_bytes]. /// - /// The size of buf can be obtained with [VmpPMatOps::vmp_prepare_tmp_bytes]. + /// The size of buf can be obtained with [MatZnxDftOps::vmp_prepare_tmp_bytes]. fn vmp_prepare_row(&self, b: &mut MatZnxDft, a: &[i64], row_i: usize, tmp_bytes: &mut [u8]); - /// Extracts the ith-row of [VmpPMat] into a [VecZnxBig]. + /// Extracts the ith-row of [MatZnxDft] into a [VecZnxBig]. /// /// # Arguments /// - /// * `b`: the [VecZnxBig] to on which to extract the row of the [VmpPMat]. - /// * `a`: [VmpPMat] on which the values are encoded. + /// * `b`: the [VecZnxBig] to on which to extract the row of the [MatZnxDft]. + /// * `a`: [MatZnxDft] on which the values are encoded. /// * `row_i`: the index of the row to extract. fn vmp_extract_row(&self, b: &mut VecZnxBig, a: &MatZnxDft, row_i: usize); - /// Prepares the ith-row of [VmpPMat] from a [VecZnxDft]. + /// Prepares the ith-row of [MatZnxDft] from a [VecZnxDft]. /// /// # Arguments /// - /// * `b`: [VmpPMat] on which the values are encoded. - /// * `a`: the [VecZnxDft] to encode on the [VmpPMat]. + /// * `b`: [MatZnxDft] on which the values are encoded. + /// * `a`: the [VecZnxDft] to encode on the [MatZnxDft]. /// * `row_i`: the index of the row to prepare. /// - /// The size of buf can be obtained with [VmpPMatOps::vmp_prepare_tmp_bytes]. + /// The size of buf can be obtained with [MatZnxDftOps::vmp_prepare_tmp_bytes]. fn vmp_prepare_row_dft(&self, b: &mut MatZnxDft, a: &VecZnxDft, row_i: usize); - /// Extracts the ith-row of [VmpPMat] into a [VecZnxDft]. + /// Extracts the ith-row of [MatZnxDft] into a [VecZnxDft]. /// /// # Arguments /// - /// * `b`: the [VecZnxDft] to on which to extract the row of the [VmpPMat]. - /// * `a`: [VmpPMat] on which the values are encoded. + /// * `b`: the [VecZnxDft] to on which to extract the row of the [MatZnxDft]. + /// * `a`: [MatZnxDft] on which the values are encoded. /// * `row_i`: the index of the row to extract. fn vmp_extract_row_dft(&self, b: &mut VecZnxDft, a: &MatZnxDft, row_i: usize); - /// Returns the size of the stratch space necessary for [VmpPMatOps::vmp_apply_dft]. + /// Returns the size of the stratch space necessary for [MatZnxDftOps::vmp_apply_dft]. /// /// # Arguments /// /// * `c_size`: number of size of the output [VecZnxDft]. /// * `a_size`: number of size of the input [VecZnx]. - /// * `rows`: number of rows of the input [VmpPMat]. - /// * `size`: number of size of the input [VmpPMat]. + /// * `rows`: number of rows of the input [MatZnxDft]. + /// * `size`: number of size of the input [MatZnxDft]. fn vmp_apply_dft_tmp_bytes(&self, c_size: usize, a_size: usize, rows: usize, size: usize) -> usize; - /// Applies the vector matrix product [VecZnxDft] x [VmpPMat]. + /// Applies the vector matrix product [VecZnxDft] x [MatZnxDft]. /// /// A vector matrix product is equivalent to a sum of [crate::SvpPPolOps::svp_apply_dft] /// where each [crate::Scalar] is a limb of the input [VecZnxDft] (equivalent to an [crate::SvpPPol]) - /// and each vector a [VecZnxDft] (row) of the [VmpPMat]. + /// and each vector a [VecZnxDft] (row) of the [MatZnxDft]. /// - /// As such, given an input [VecZnx] of `i` size and a [VmpPMat] of `i` rows and + /// As such, given an input [VecZnx] of `i` size and a [MatZnxDft] of `i` rows and /// `j` size, the output is a [VecZnx] of `j` size. /// /// If there is a mismatch between the dimensions the largest valid ones are used. @@ -221,17 +194,17 @@ pub trait MatZnxDftOps { /// /// * `c`: the output of the vector matrix product, as a [VecZnxDft]. /// * `a`: the left operand [VecZnx] of the vector matrix product. - /// * `b`: the right operand [VmpPMat] of the vector matrix product. - /// * `buf`: scratch space, the size can be obtained with [VmpPMatOps::vmp_apply_dft_tmp_bytes]. + /// * `b`: the right operand [MatZnxDft] of the vector matrix product. + /// * `buf`: scratch space, the size can be obtained with [MatZnxDftOps::vmp_apply_dft_tmp_bytes]. fn vmp_apply_dft(&self, c: &mut VecZnxDft, a: &VecZnx, b: &MatZnxDft, buf: &mut [u8]); - /// Applies the vector matrix product [VecZnxDft] x [VmpPMat] and adds on the receiver. + /// Applies the vector matrix product [VecZnxDft] x [MatZnxDft] and adds on the receiver. /// /// A vector matrix product is equivalent to a sum of [crate::SvpPPolOps::svp_apply_dft] /// where each [crate::Scalar] is a limb of the input [VecZnxDft] (equivalent to an [crate::SvpPPol]) - /// and each vector a [VecZnxDft] (row) of the [VmpPMat]. + /// and each vector a [VecZnxDft] (row) of the [MatZnxDft]. /// - /// As such, given an input [VecZnx] of `i` size and a [VmpPMat] of `i` rows and + /// As such, given an input [VecZnx] of `i` size and a [MatZnxDft] of `i` rows and /// `j` size, the output is a [VecZnx] of `j` size. /// /// If there is a mismatch between the dimensions the largest valid ones are used. @@ -247,28 +220,28 @@ pub trait MatZnxDftOps { /// /// * `c`: the operand on which the output of the vector matrix product is added, as a [VecZnxDft]. /// * `a`: the left operand [VecZnx] of the vector matrix product. - /// * `b`: the right operand [VmpPMat] of the vector matrix product. - /// * `buf`: scratch space, the size can be obtained with [VmpPMatOps::vmp_apply_dft_tmp_bytes]. + /// * `b`: the right operand [MatZnxDft] of the vector matrix product. + /// * `buf`: scratch space, the size can be obtained with [MatZnxDftOps::vmp_apply_dft_tmp_bytes]. fn vmp_apply_dft_add(&self, c: &mut VecZnxDft, a: &VecZnx, b: &MatZnxDft, buf: &mut [u8]); - /// Returns the size of the stratch space necessary for [VmpPMatOps::vmp_apply_dft_to_dft]. + /// Returns the size of the stratch space necessary for [MatZnxDftOps::vmp_apply_dft_to_dft]. /// /// # Arguments /// /// * `c_size`: number of size of the output [VecZnxDft]. /// * `a_size`: number of size of the input [VecZnxDft]. - /// * `rows`: number of rows of the input [VmpPMat]. - /// * `size`: number of size of the input [VmpPMat]. + /// * `rows`: number of rows of the input [MatZnxDft]. + /// * `size`: number of size of the input [MatZnxDft]. fn vmp_apply_dft_to_dft_tmp_bytes(&self, c_size: usize, a_size: usize, rows: usize, size: usize) -> usize; - /// Applies the vector matrix product [VecZnxDft] x [VmpPMat]. - /// The size of `buf` is given by [VmpPMatOps::vmp_apply_dft_to_dft_tmp_bytes]. + /// Applies the vector matrix product [VecZnxDft] x [MatZnxDft]. + /// The size of `buf` is given by [MatZnxDftOps::vmp_apply_dft_to_dft_tmp_bytes]. /// /// A vector matrix product is equivalent to a sum of [crate::SvpPPolOps::svp_apply_dft] /// where each [crate::Scalar] is a limb of the input [VecZnxDft] (equivalent to an [crate::SvpPPol]) - /// and each vector a [VecZnxDft] (row) of the [VmpPMat]. + /// and each vector a [VecZnxDft] (row) of the [MatZnxDft]. /// - /// As such, given an input [VecZnx] of `i` size and a [VmpPMat] of `i` rows and + /// As such, given an input [VecZnx] of `i` size and a [MatZnxDft] of `i` rows and /// `j` size, the output is a [VecZnx] of `j` size. /// /// If there is a mismatch between the dimensions the largest valid ones are used. @@ -284,18 +257,18 @@ pub trait MatZnxDftOps { /// /// * `c`: the output of the vector matrix product, as a [VecZnxDft]. /// * `a`: the left operand [VecZnxDft] of the vector matrix product. - /// * `b`: the right operand [VmpPMat] of the vector matrix product. - /// * `buf`: scratch space, the size can be obtained with [VmpPMatOps::vmp_apply_dft_to_dft_tmp_bytes]. + /// * `b`: the right operand [MatZnxDft] of the vector matrix product. + /// * `buf`: scratch space, the size can be obtained with [MatZnxDftOps::vmp_apply_dft_to_dft_tmp_bytes]. fn vmp_apply_dft_to_dft(&self, c: &mut VecZnxDft, a: &VecZnxDft, b: &MatZnxDft, buf: &mut [u8]); - /// Applies the vector matrix product [VecZnxDft] x [VmpPMat] and adds on top of the receiver instead of overwritting it. - /// The size of `buf` is given by [VmpPMatOps::vmp_apply_dft_to_dft_tmp_bytes]. + /// Applies the vector matrix product [VecZnxDft] x [MatZnxDft] and adds on top of the receiver instead of overwritting it. + /// The size of `buf` is given by [MatZnxDftOps::vmp_apply_dft_to_dft_tmp_bytes]. /// /// A vector matrix product is equivalent to a sum of [crate::SvpPPolOps::svp_apply_dft] /// where each [crate::Scalar] is a limb of the input [VecZnxDft] (equivalent to an [crate::SvpPPol]) - /// and each vector a [VecZnxDft] (row) of the [VmpPMat]. + /// and each vector a [VecZnxDft] (row) of the [MatZnxDft]. /// - /// As such, given an input [VecZnx] of `i` size and a [VmpPMat] of `i` rows and + /// As such, given an input [VecZnx] of `i` size and a [MatZnxDft] of `i` rows and /// `j` size, the output is a [VecZnx] of `j` size. /// /// If there is a mismatch between the dimensions the largest valid ones are used. @@ -311,18 +284,18 @@ pub trait MatZnxDftOps { /// /// * `c`: the operand on which the output of the vector matrix product is added, as a [VecZnxDft]. /// * `a`: the left operand [VecZnxDft] of the vector matrix product. - /// * `b`: the right operand [VmpPMat] of the vector matrix product. - /// * `buf`: scratch space, the size can be obtained with [VmpPMatOps::vmp_apply_dft_to_dft_tmp_bytes]. + /// * `b`: the right operand [MatZnxDft] of the vector matrix product. + /// * `buf`: scratch space, the size can be obtained with [MatZnxDftOps::vmp_apply_dft_to_dft_tmp_bytes]. fn vmp_apply_dft_to_dft_add(&self, c: &mut VecZnxDft, a: &VecZnxDft, b: &MatZnxDft, buf: &mut [u8]); - /// Applies the vector matrix product [VecZnxDft] x [VmpPMat] in place. - /// The size of `buf` is given by [VmpPMatOps::vmp_apply_dft_to_dft_tmp_bytes]. + /// Applies the vector matrix product [VecZnxDft] x [MatZnxDft] in place. + /// The size of `buf` is given by [MatZnxDftOps::vmp_apply_dft_to_dft_tmp_bytes]. /// /// A vector matrix product is equivalent to a sum of [crate::SvpPPolOps::svp_apply_dft] /// where each [crate::Scalar] is a limb of the input [VecZnxDft] (equivalent to an [crate::SvpPPol]) - /// and each vector a [VecZnxDft] (row) of the [VmpPMat]. + /// and each vector a [VecZnxDft] (row) of the [MatZnxDft]. /// - /// As such, given an input [VecZnx] of `i` size and a [VmpPMat] of `i` rows and + /// As such, given an input [VecZnx] of `i` size and a [MatZnxDft] of `i` rows and /// `j` size, the output is a [VecZnx] of `j` size. /// /// If there is a mismatch between the dimensions the largest valid ones are used. @@ -337,8 +310,8 @@ pub trait MatZnxDftOps { /// # Arguments /// /// * `b`: the input and output of the vector matrix product, as a [VecZnxDft]. - /// * `a`: the right operand [VmpPMat] of the vector matrix product. - /// * `buf`: scratch space, the size can be obtained with [VmpPMatOps::vmp_apply_dft_to_dft_tmp_bytes]. + /// * `a`: the right operand [MatZnxDft] of the vector matrix product. + /// * `buf`: scratch space, the size can be obtained with [MatZnxDftOps::vmp_apply_dft_to_dft_tmp_bytes]. fn vmp_apply_dft_to_dft_inplace(&self, b: &mut VecZnxDft, a: &MatZnxDft, buf: &mut [u8]); } @@ -404,7 +377,7 @@ impl MatZnxDftOps for Module { unsafe { vmp::vmp_extract_row( self.ptr, - b.ptr as *mut vec_znx_big_t, + b.as_mut_ptr() as *mut vec_znx_big_t, a.as_ptr() as *const vmp_pmat_t, row_i as u64, a.rows() as u64, @@ -423,7 +396,7 @@ impl MatZnxDftOps for Module { vmp::vmp_prepare_row_dft( self.ptr, b.as_mut_ptr() as *mut vmp_pmat_t, - a.ptr as *const vec_znx_dft_t, + a.as_ptr() as *const vec_znx_dft_t, row_i as u64, b.rows() as u64, b.size() as u64, @@ -440,7 +413,7 @@ impl MatZnxDftOps for Module { unsafe { vmp::vmp_extract_row_dft( self.ptr, - b.ptr as *mut vec_znx_dft_t, + b.as_mut_ptr() as *mut vec_znx_dft_t, a.as_ptr() as *const vmp_pmat_t, row_i as u64, a.rows() as u64, @@ -470,7 +443,7 @@ impl MatZnxDftOps for Module { unsafe { vmp::vmp_apply_dft( self.ptr, - c.ptr as *mut vec_znx_dft_t, + c.as_mut_ptr() as *mut vec_znx_dft_t, c.size() as u64, a.as_ptr(), a.size() as u64, @@ -492,7 +465,7 @@ impl MatZnxDftOps for Module { unsafe { vmp::vmp_apply_dft_add( self.ptr, - c.ptr as *mut vec_znx_dft_t, + c.as_mut_ptr() as *mut vec_znx_dft_t, c.size() as u64, a.as_ptr(), a.size() as u64, @@ -526,9 +499,9 @@ impl MatZnxDftOps for Module { unsafe { vmp::vmp_apply_dft_to_dft( self.ptr, - c.ptr as *mut vec_znx_dft_t, + c.as_mut_ptr() as *mut vec_znx_dft_t, c.size() as u64, - a.ptr as *const vec_znx_dft_t, + a.as_ptr() as *const vec_znx_dft_t, a.size() as u64, b.as_ptr() as *const vmp_pmat_t, b.rows() as u64, @@ -553,9 +526,9 @@ impl MatZnxDftOps for Module { unsafe { vmp::vmp_apply_dft_to_dft_add( self.ptr, - c.ptr as *mut vec_znx_dft_t, + c.as_mut_ptr() as *mut vec_znx_dft_t, c.size() as u64, - a.ptr as *const vec_znx_dft_t, + a.as_ptr() as *const vec_znx_dft_t, a.size() as u64, b.as_ptr() as *const vmp_pmat_t, b.rows() as u64, @@ -574,9 +547,9 @@ impl MatZnxDftOps for Module { unsafe { vmp::vmp_apply_dft_to_dft( self.ptr, - b.ptr as *mut vec_znx_dft_t, + b.as_mut_ptr() as *mut vec_znx_dft_t, b.size() as u64, - b.ptr as *mut vec_znx_dft_t, + b.as_ptr() as *mut vec_znx_dft_t, b.size() as u64, a.as_ptr() as *const vmp_pmat_t, a.rows() as u64, @@ -591,7 +564,7 @@ impl MatZnxDftOps for Module { mod tests { use crate::{ FFT64, MatZnxDft, MatZnxDftOps, Module, Sampling, VecZnx, VecZnxBig, VecZnxBigOps, VecZnxDft, VecZnxDftOps, VecZnxOps, - ZnxLayout, alloc_aligned, + alloc_aligned, znx_base::ZnxLayout, }; use sampling::source::Source; @@ -614,7 +587,7 @@ mod tests { for row_i in 0..vpmat_rows { let mut source: Source = Source::new([0u8; 32]); module.fill_uniform(log_base2k, &mut a, 0, vpmat_size, &mut source); - module.vec_znx_dft(&mut a_dft, &a); + module.vec_znx_dft(&mut a_dft, 0, &a, 0); module.vmp_prepare_row(&mut vmpmat_0, &a.raw(), row_i, &mut tmp_bytes); // Checks that prepare(mat_znx_dft, a) = prepare_dft(mat_znx_dft, a_dft) @@ -627,7 +600,7 @@ mod tests { // Checks that a_big = extract(prepare_dft(mat_znx_dft, a_dft), b_big) module.vmp_extract_row(&mut b_big, &vmpmat_0, row_i); - module.vec_znx_idft(&mut a_big, &a_dft, &mut tmp_bytes); + module.vec_znx_idft(&mut a_big, 0, &a_dft, 0, &mut tmp_bytes); assert_eq!(a_big.raw(), b_big.raw()); } diff --git a/base2k/src/sampling.rs b/base2k/src/sampling.rs index 5261207..b52c4db 100644 --- a/base2k/src/sampling.rs +++ b/base2k/src/sampling.rs @@ -1,4 +1,4 @@ -use crate::{Backend, Module, VecZnx, ZnxLayout}; +use crate::{Backend, Module, VecZnx, znx_base::ZnxLayout}; use rand_distr::{Distribution, Normal}; use sampling::source::Source; @@ -106,7 +106,7 @@ impl Sampling for Module { #[cfg(test)] mod tests { use super::Sampling; - use crate::{FFT64, Module, Stats, VecZnx, ZnxBase, ZnxLayout}; + use crate::{FFT64, Module, Stats, VecZnx, VecZnxOps, znx_base::ZnxLayout}; use sampling::source::Source; #[test] @@ -120,7 +120,7 @@ mod tests { let zero: Vec = vec![0; n]; let one_12_sqrt: f64 = 0.28867513459481287; (0..cols).for_each(|col_i| { - let mut a: VecZnx = VecZnx::new(&module, cols, size); + let mut a: VecZnx = module.new_vec_znx(cols, size); module.fill_uniform(log_base2k, &mut a, col_i, size, &mut source); (0..cols).for_each(|col_j| { if col_j != col_i { @@ -154,7 +154,7 @@ mod tests { let zero: Vec = vec![0; n]; let k_f64: f64 = (1u64 << log_k as u64) as f64; (0..cols).for_each(|col_i| { - let mut a: VecZnx = VecZnx::new(&module, cols, size); + let mut a: VecZnx = module.new_vec_znx(cols, size); module.add_normal(log_base2k, &mut a, col_i, log_k, &mut source, sigma, bound); (0..cols).for_each(|col_j| { if col_j != col_i { diff --git a/base2k/src/scalar_znx_dft.rs b/base2k/src/scalar_znx_dft.rs index 474135b..07e156d 100644 --- a/base2k/src/scalar_znx_dft.rs +++ b/base2k/src/scalar_znx_dft.rs @@ -2,9 +2,8 @@ use std::marker::PhantomData; use crate::ffi::svp::{self, svp_ppol_t}; use crate::ffi::vec_znx_dft::vec_znx_dft_t; -use crate::{Backend, FFT64, Module, VecZnx, VecZnxDft, ZnxLayout, assert_alignement}; - -use crate::{ZnxInfos, alloc_aligned, cast_mut}; +use crate::znx_base::{ZnxBase, ZnxInfos, ZnxLayout, ZnxSliceSize}; +use crate::{Backend, FFT64, Module, VecZnx, VecZnxDft, alloc_aligned, assert_alignement, cast_mut}; use rand::seq::SliceRandom; use rand_core::RngCore; use rand_distr::{Distribution, weighted::WeightedIndex}; @@ -118,11 +117,14 @@ impl Scalar { pub fn as_vec_znx(&self) -> VecZnx { VecZnx { - n: self.n, - cols: 1, - size: 1, - data: Vec::new(), - ptr: self.ptr, + inner: ZnxBase { + n: self.n, + rows: 1, + cols: 1, + size: 1, + data: Vec::new(), + ptr: self.ptr as *mut u8, + }, } } } @@ -159,7 +161,7 @@ pub struct ScalarZnxDft { /// An [SvpPPol] an be seen as a [VecZnxDft] of one limb. impl ScalarZnxDft { pub fn new(module: &Module) -> Self { - module.new_svp_ppol() + module.new_scalar_znx_dft() } /// Returns the ring degree of the [SvpPPol]. @@ -168,14 +170,14 @@ impl ScalarZnxDft { } pub fn bytes_of(module: &Module) -> usize { - module.bytes_of_svp_ppol() + module.bytes_of_scalar_znx_dft() } pub fn from_bytes(module: &Module, bytes: &mut [u8]) -> Self { #[cfg(debug_assertions)] { assert_alignement(bytes.as_ptr()); - assert_eq!(bytes.len(), module.bytes_of_svp_ppol()); + assert_eq!(bytes.len(), module.bytes_of_scalar_znx_dft()); } unsafe { Self { @@ -191,7 +193,7 @@ impl ScalarZnxDft { #[cfg(debug_assertions)] { assert_alignement(tmp_bytes.as_ptr()); - assert_eq!(tmp_bytes.len(), module.bytes_of_svp_ppol()); + assert_eq!(tmp_bytes.len(), module.bytes_of_scalar_znx_dft()); } Self { n: module.n(), @@ -209,33 +211,33 @@ impl ScalarZnxDft { pub trait ScalarZnxDftOps { /// Allocates a new [SvpPPol]. - fn new_svp_ppol(&self) -> ScalarZnxDft; + fn new_scalar_znx_dft(&self) -> ScalarZnxDft; /// Returns the minimum number of bytes necessary to allocate /// a new [SvpPPol] through [SvpPPol::from_bytes] ro. - fn bytes_of_svp_ppol(&self) -> usize; + fn bytes_of_scalar_znx_dft(&self) -> usize; /// Allocates a new [SvpPPol] from an array of bytes. /// The array of bytes is owned by the [SvpPPol]. /// The method will panic if bytes.len() < [SvpPPolOps::bytes_of_svp_ppol] - fn new_svp_ppol_from_bytes(&self, bytes: &mut [u8]) -> ScalarZnxDft; + fn new_scalar_znx_dft_from_bytes(&self, bytes: &mut [u8]) -> ScalarZnxDft; /// Allocates a new [SvpPPol] from an array of bytes. /// The array of bytes is borrowed by the [SvpPPol]. /// The method will panic if bytes.len() < [SvpPPolOps::bytes_of_svp_ppol] - fn new_svp_ppol_from_bytes_borrow(&self, tmp_bytes: &mut [u8]) -> ScalarZnxDft; + fn new_scalar_znx_dft_from_bytes_borrow(&self, tmp_bytes: &mut [u8]) -> ScalarZnxDft; /// Prepares a [crate::Scalar] for a [SvpPPolOps::svp_apply_dft]. fn svp_prepare(&self, svp_ppol: &mut ScalarZnxDft, a: &Scalar); /// Applies the [SvpPPol] x [VecZnxDft] product, where each limb of /// the [VecZnxDft] is multiplied with [SvpPPol]. - fn svp_apply_dft(&self, c: &mut VecZnxDft, a: &ScalarZnxDft, b: &VecZnx, b_col: usize); + fn svp_apply_dft(&self, res: &mut VecZnxDft, res_col: usize, a: &ScalarZnxDft, b: &VecZnx, b_col: usize); } impl ScalarZnxDftOps for Module { - fn new_svp_ppol(&self) -> ScalarZnxDft { - let mut data: Vec = alloc_aligned::(self.bytes_of_svp_ppol()); + fn new_scalar_znx_dft(&self) -> ScalarZnxDft { + let mut data: Vec = alloc_aligned::(self.bytes_of_scalar_znx_dft()); let ptr: *mut u8 = data.as_mut_ptr(); ScalarZnxDft:: { data: data, @@ -245,28 +247,28 @@ impl ScalarZnxDftOps for Module { } } - fn bytes_of_svp_ppol(&self) -> usize { + fn bytes_of_scalar_znx_dft(&self) -> usize { unsafe { svp::bytes_of_svp_ppol(self.ptr) as usize } } - fn new_svp_ppol_from_bytes(&self, bytes: &mut [u8]) -> ScalarZnxDft { + fn new_scalar_znx_dft_from_bytes(&self, bytes: &mut [u8]) -> ScalarZnxDft { ScalarZnxDft::from_bytes(self, bytes) } - fn new_svp_ppol_from_bytes_borrow(&self, tmp_bytes: &mut [u8]) -> ScalarZnxDft { + fn new_scalar_znx_dft_from_bytes_borrow(&self, tmp_bytes: &mut [u8]) -> ScalarZnxDft { ScalarZnxDft::from_bytes_borrow(self, tmp_bytes) } - fn svp_prepare(&self, svp_ppol: &mut ScalarZnxDft, a: &Scalar) { - unsafe { svp::svp_prepare(self.ptr, svp_ppol.ptr as *mut svp_ppol_t, a.as_ptr()) } + fn svp_prepare(&self, res: &mut ScalarZnxDft, a: &Scalar) { + unsafe { svp::svp_prepare(self.ptr, res.ptr as *mut svp_ppol_t, a.as_ptr()) } } - fn svp_apply_dft(&self, c: &mut VecZnxDft, a: &ScalarZnxDft, b: &VecZnx, b_col: usize) { + fn svp_apply_dft(&self, res: &mut VecZnxDft, res_col: usize, a: &ScalarZnxDft, b: &VecZnx, b_col: usize) { unsafe { svp::svp_apply_dft( self.ptr, - c.ptr as *mut vec_znx_dft_t, - c.size() as u64, + res.at_mut_ptr(res_col, 0) as *mut vec_znx_dft_t, + res.size() as u64, a.ptr as *const svp_ppol_t, b.at_ptr(b_col, 0), b.size() as u64, diff --git a/base2k/src/stats.rs b/base2k/src/stats.rs index 4e2a512..a1946ab 100644 --- a/base2k/src/stats.rs +++ b/base2k/src/stats.rs @@ -1,4 +1,5 @@ -use crate::{Encoding, VecZnx, ZnxInfos}; +use crate::znx_base::ZnxInfos; +use crate::{Encoding, VecZnx}; use rug::Float; use rug::float::Round; use rug::ops::{AddAssignRound, DivAssignRound, SubAssignRound}; diff --git a/base2k/src/vec_znx.rs b/base2k/src/vec_znx.rs index 53aeb39..125f32e 100644 --- a/base2k/src/vec_znx.rs +++ b/base2k/src/vec_znx.rs @@ -1,12 +1,13 @@ use crate::Backend; -use crate::ZnxBase; +use crate::Module; +use crate::assert_alignement; use crate::cast_mut; use crate::ffi::znx; -use crate::switch_degree; -use crate::{Module, ZnxBasics, ZnxInfos, ZnxLayout}; -use crate::{alloc_aligned, assert_alignement}; +use crate::znx_base::{GetZnxBase, ZnxAlloc, ZnxBase, ZnxBasics, ZnxInfos, ZnxLayout, ZnxSliceSize, switch_degree}; use std::cmp::min; +pub const VEC_ZNX_ROWS: usize = 1; + /// [VecZnx] represents collection of contiguously stacked vector of small norm polynomials of /// Zn\[X\] with [i64] coefficients. /// A [VecZnx] is composed of multiple Zn\[X\] polynomials stored in a single contiguous array @@ -17,56 +18,54 @@ use std::cmp::min; /// Given 3 polynomials (a, b, c) of Zn\[X\], each with 4 columns, then the memory /// layout is: `[a0, b0, c0, a1, b1, c1, a2, b2, c2, a3, b3, c3]`, where ai, bi, ci /// are small polynomials of Zn\[X\]. -#[derive(Clone)] pub struct VecZnx { - /// Polynomial degree. - pub n: usize, - - /// The number of polynomials - pub cols: usize, - - /// The number of size per polynomial (a.k.a small polynomials). - pub size: usize, - - /// Polynomial coefficients, as a contiguous array. Each col is equally spaced by n. - pub data: Vec, - - /// Pointer to data (data can be enpty if [VecZnx] borrows space instead of owning it). - pub ptr: *mut i64, + pub inner: ZnxBase, } -impl ZnxInfos for VecZnx { - fn n(&self) -> usize { - self.n +impl GetZnxBase for VecZnx { + fn znx(&self) -> &ZnxBase { + &self.inner } - fn rows(&self) -> usize { - 1 + fn znx_mut(&mut self) -> &mut ZnxBase { + &mut self.inner } +} - fn cols(&self) -> usize { - self.cols - } +impl ZnxInfos for VecZnx {} - fn size(&self) -> usize { - self.size +impl ZnxSliceSize for VecZnx { + fn sl(&self) -> usize { + self.cols() * self.n() } } impl ZnxLayout for VecZnx { type Scalar = i64; - - fn as_ptr(&self) -> *const Self::Scalar { - self.ptr - } - - fn as_mut_ptr(&mut self) -> *mut Self::Scalar { - self.ptr - } } impl ZnxBasics for VecZnx {} +impl ZnxAlloc for VecZnx { + type Scalar = i64; + + fn from_bytes_borrow(module: &Module, _rows: usize, cols: usize, size: usize, bytes: &mut [u8]) -> VecZnx { + debug_assert_eq!(bytes.len(), Self::bytes_of(module, _rows, cols, size)); + VecZnx { + inner: ZnxBase::from_bytes_borrow(module.n(), VEC_ZNX_ROWS, cols, size, bytes), + } + } + + fn bytes_of(module: &Module, _rows: usize, cols: usize, size: usize) -> usize { + debug_assert_eq!( + _rows, VEC_ZNX_ROWS, + "rows != {} not supported for VecZnx", + VEC_ZNX_ROWS + ); + module.n() * cols * size * size_of::() + } +} + /// Copies the coefficients of `a` on the receiver. /// Copy is done with the minimum size matching both backing arrays. /// Panics if the cols do not match. @@ -78,80 +77,6 @@ pub fn copy_vec_znx_from(b: &mut VecZnx, a: &VecZnx) { data_b[..size].copy_from_slice(&data_a[..size]) } -impl ZnxBase for VecZnx { - type Scalar = i64; - - /// Allocates a new [VecZnx] composed of #size polynomials of Z\[X\]. - fn new(module: &Module, cols: usize, size: usize) -> Self { - let n: usize = module.n(); - #[cfg(debug_assertions)] - { - assert!(n > 0); - assert!(n & (n - 1) == 0); - assert!(cols > 0); - assert!(size > 0); - } - let mut data: Vec = alloc_aligned::(Self::bytes_of(module, cols, size)); - let ptr: *mut i64 = data.as_mut_ptr(); - Self { - n: n, - cols: cols, - size: size, - data: data, - ptr: ptr, - } - } - - fn bytes_of(module: &Module, cols: usize, size: usize) -> usize { - module.n() * cols * size * size_of::() - } - - /// Returns a new struct implementing [VecZnx] with the provided data as backing array. - /// - /// The struct will take ownership of buf[..[Self::bytes_of]] - /// - /// User must ensure that data is properly alligned and that - /// the size of data is equal to [Self::bytes_of]. - fn from_bytes(module: &Module, cols: usize, size: usize, bytes: &mut [u8]) -> Self { - let n: usize = module.n(); - #[cfg(debug_assertions)] - { - assert!(cols > 0); - assert!(size > 0); - assert_eq!(bytes.len(), Self::bytes_of(module, cols, size)); - assert_alignement(bytes.as_ptr()); - } - unsafe { - let bytes_i64: &mut [i64] = cast_mut::(bytes); - let ptr: *mut i64 = bytes_i64.as_mut_ptr(); - Self { - n: n, - cols: cols, - size: size, - data: Vec::from_raw_parts(ptr, bytes.len(), bytes.len()), - ptr: ptr, - } - } - } - - fn from_bytes_borrow(module: &Module, cols: usize, size: usize, bytes: &mut [u8]) -> Self { - #[cfg(debug_assertions)] - { - assert!(cols > 0); - assert!(size > 0); - assert!(bytes.len() >= Self::bytes_of(module, cols, size)); - assert_alignement(bytes.as_ptr()); - } - Self { - n: module.n(), - cols: cols, - size: size, - data: Vec::new(), - ptr: bytes.as_mut_ptr() as *mut i64, - } - } -} - impl VecZnx { /// Truncates the precision of the [VecZnx] by k bits. /// @@ -165,11 +90,12 @@ impl VecZnx { } if !self.borrowing() { - self.data + self.inner + .data .truncate(self.n() * self.cols() * (self.size() - k / log_base2k)); } - self.size -= k / log_base2k; + self.inner.size -= k / log_base2k; let k_rem: usize = k % log_base2k; @@ -185,10 +111,6 @@ impl VecZnx { copy_vec_znx_from(self, a); } - pub fn borrowing(&self) -> bool { - self.data.len() == 0 - } - pub fn normalize(&mut self, log_base2k: usize, carry: &mut [u8]) { normalize(log_base2k, self, carry) } diff --git a/base2k/src/vec_znx_big.rs b/base2k/src/vec_znx_big.rs index 67b75a2..cbcd4b9 100644 --- a/base2k/src/vec_znx_big.rs +++ b/base2k/src/vec_znx_big.rs @@ -1,115 +1,71 @@ use crate::ffi::vec_znx_big; -use crate::{Backend, FFT64, Module, ZnxBase, ZnxBasics, ZnxInfos, ZnxLayout, alloc_aligned, assert_alignement}; +use crate::znx_base::{GetZnxBase, ZnxAlloc, ZnxBase, ZnxBasics, ZnxInfos, ZnxLayout, ZnxSliceSize}; +use crate::{Backend, FFT64, Module, NTT120}; use std::marker::PhantomData; +const VEC_ZNX_BIG_ROWS: usize = 1; + pub struct VecZnxBig { - pub data: Vec, - pub ptr: *mut u8, - pub n: usize, - pub cols: usize, - pub size: usize, + pub inner: ZnxBase, pub _marker: PhantomData, } -impl ZnxBasics for VecZnxBig {} - -impl ZnxBase for VecZnxBig { - type Scalar = u8; - - fn new(module: &Module, cols: usize, size: usize) -> Self { - #[cfg(debug_assertions)] - { - assert!(cols > 0); - assert!(size > 0); - } - let mut data: Vec = alloc_aligned(Self::bytes_of(module, cols, size)); - let ptr: *mut Self::Scalar = data.as_mut_ptr(); - Self { - data: data, - ptr: ptr, - n: module.n(), - cols: cols, - size: size, - _marker: PhantomData, - } +impl GetZnxBase for VecZnxBig { + fn znx(&self) -> &ZnxBase { + &self.inner } - fn bytes_of(module: &Module, cols: usize, size: usize) -> usize { - unsafe { vec_znx_big::bytes_of_vec_znx_big(module.ptr, size as u64) as usize * cols } - } - - /// Returns a new [VecZnxBig] with the provided data as backing array. - /// User must ensure that data is properly alligned and that - /// the size of data is at least equal to [Module::bytes_of_vec_znx_big]. - fn from_bytes(module: &Module, cols: usize, size: usize, bytes: &mut [Self::Scalar]) -> Self { - #[cfg(debug_assertions)] - { - assert!(cols > 0); - assert!(size > 0); - assert_eq!(bytes.len(), Self::bytes_of(module, cols, size)); - assert_alignement(bytes.as_ptr()) - }; - unsafe { - Self { - data: Vec::from_raw_parts(bytes.as_mut_ptr(), bytes.len(), bytes.len()), - ptr: bytes.as_mut_ptr(), - n: module.n(), - cols: cols, - size: size, - _marker: PhantomData, - } - } - } - - fn from_bytes_borrow(module: &Module, cols: usize, size: usize, bytes: &mut [Self::Scalar]) -> Self { - #[cfg(debug_assertions)] - { - assert!(cols > 0); - assert!(size > 0); - assert_eq!(bytes.len(), Self::bytes_of(module, cols, size)); - assert_alignement(bytes.as_ptr()); - } - Self { - data: Vec::new(), - ptr: bytes.as_mut_ptr(), - n: module.n(), - cols: cols, - size: size, - _marker: PhantomData, - } + fn znx_mut(&mut self) -> &mut ZnxBase { + &mut self.inner } } -impl ZnxInfos for VecZnxBig { - fn n(&self) -> usize { - self.n +impl ZnxInfos for VecZnxBig {} + +impl ZnxAlloc for VecZnxBig { + type Scalar = u8; + + fn from_bytes_borrow(module: &Module, _rows: usize, cols: usize, size: usize, bytes: &mut [u8]) -> Self { + VecZnxBig { + inner: ZnxBase::from_bytes_borrow(module.n(), VEC_ZNX_BIG_ROWS, cols, size, bytes), + _marker: PhantomData, + } } - fn cols(&self) -> usize { - self.cols - } - - fn rows(&self) -> usize { - 1 - } - - fn size(&self) -> usize { - self.size + fn bytes_of(module: &Module, _rows: usize, cols: usize, size: usize) -> usize { + debug_assert_eq!( + _rows, VEC_ZNX_BIG_ROWS, + "rows != {} not supported for VecZnxBig", + VEC_ZNX_BIG_ROWS + ); + unsafe { vec_znx_big::bytes_of_vec_znx_big(module.ptr, size as u64) as usize * cols } } } impl ZnxLayout for VecZnxBig { type Scalar = i64; +} - fn as_ptr(&self) -> *const Self::Scalar { - self.ptr as *const Self::Scalar - } +impl ZnxLayout for VecZnxBig { + type Scalar = i128; +} - fn as_mut_ptr(&mut self) -> *mut Self::Scalar { - self.ptr as *mut Self::Scalar +impl ZnxBasics for VecZnxBig {} + +impl ZnxSliceSize for VecZnxBig { + fn sl(&self) -> usize { + self.n() } } +impl ZnxSliceSize for VecZnxBig { + fn sl(&self) -> usize { + self.n() * 4 + } +} + +impl ZnxBasics for VecZnxBig {} + impl VecZnxBig { pub fn print(&self, n: usize) { (0..self.size()).for_each(|i| println!("{}: {:?}", i, &self.at_limb(i)[..n])); diff --git a/base2k/src/vec_znx_big_ops.rs b/base2k/src/vec_znx_big_ops.rs index e59fda1..9c6feee 100644 --- a/base2k/src/vec_znx_big_ops.rs +++ b/base2k/src/vec_znx_big_ops.rs @@ -1,5 +1,6 @@ -use crate::ffi::vec_znx; -use crate::{Backend, FFT64, Module, VecZnx, VecZnxBig, VecZnxOps, ZnxBase, ZnxInfos, ZnxLayout, assert_alignement}; +use crate::ffi::vec_znx_big::{self, vec_znx_big_t}; +use crate::znx_base::{ZnxAlloc, ZnxInfos, ZnxLayout, ZnxSliceSize}; +use crate::{Backend, FFT64, Module, VecZnx, VecZnxBig, VecZnxOps, assert_alignement}; pub trait VecZnxBigOps { /// Allocates a vector Z[X]/(X^N+1) that stores not normalized values. @@ -17,7 +18,7 @@ pub trait VecZnxBigOps { /// /// # Panics /// If `bytes.len()` < [Module::bytes_of_vec_znx_big]. - fn new_vec_znx_big_from_bytes(&self, cols: usize, size: usize, bytes: &mut [u8]) -> VecZnxBig; + fn new_vec_znx_big_from_bytes(&self, cols: usize, size: usize, bytes: Vec) -> VecZnxBig; /// Returns a new [VecZnxBig] with the provided bytes array as backing array. /// @@ -41,74 +42,74 @@ pub trait VecZnxBigOps { fn vec_znx_big_add( &self, res: &mut VecZnxBig, - col_res: usize, + res_col: usize, a: &VecZnxBig, - col_a: usize, + a_col: usize, b: &VecZnxBig, - col_b: usize, + b_col: usize, ); /// Adds `a` to `b` and stores the result on `b`. - fn vec_znx_big_add_inplace(&self, res: &mut VecZnxBig, col_res: usize, a: &VecZnxBig, col_a: usize); + fn vec_znx_big_add_inplace(&self, res: &mut VecZnxBig, res_col: usize, a: &VecZnxBig, a_col: usize); /// Adds `a` to `b` and stores the result on `c`. fn vec_znx_big_add_small( &self, res: &mut VecZnxBig, - col_res: usize, - a: &VecZnx, - col_a: usize, - b: &VecZnxBig, - col_b: usize, + res_col: usize, + a: &VecZnxBig, + a_col: usize, + b: &VecZnx, + b_col: usize, ); /// Adds `a` to `b` and stores the result on `b`. - fn vec_znx_big_add_small_inplace(&self, res: &mut VecZnxBig, col_res: usize, a: &VecZnx, col_a: usize); + fn vec_znx_big_add_small_inplace(&self, res: &mut VecZnxBig, res_col: usize, a: &VecZnx, a_col: usize); /// Subtracts `a` to `b` and stores the result on `c`. fn vec_znx_big_sub( &self, res: &mut VecZnxBig, - col_res: usize, + res_col: usize, a: &VecZnxBig, - col_a: usize, + a_col: usize, b: &VecZnxBig, - col_b: usize, + b_col: usize, ); /// Subtracts `a` to `b` and stores the result on `b`. - fn vec_znx_big_sub_ab_inplace(&self, res: &mut VecZnxBig, col_res: usize, a: &VecZnxBig, col_a: usize); + fn vec_znx_big_sub_ab_inplace(&self, res: &mut VecZnxBig, res_col: usize, a: &VecZnxBig, a_col: usize); /// Subtracts `b` to `a` and stores the result on `b`. - fn vec_znx_big_sub_ba_inplace(&self, res: &mut VecZnxBig, col_res: usize, a: &VecZnxBig, col_a: usize); + fn vec_znx_big_sub_ba_inplace(&self, res: &mut VecZnxBig, res_col: usize, a: &VecZnxBig, a_col: usize); /// Subtracts `b` to `a` and stores the result on `c`. fn vec_znx_big_sub_small_a( &self, res: &mut VecZnxBig, - col_res: usize, + res_col: usize, a: &VecZnx, - col_a: usize, + a_col: usize, b: &VecZnxBig, - col_b: usize, + b_col: usize, ); /// Subtracts `a` to `b` and stores the result on `b`. - fn vec_znx_big_sub_small_a_inplace(&self, res: &mut VecZnxBig, col_res: usize, a: &VecZnx, col_a: usize); + fn vec_znx_big_sub_small_a_inplace(&self, res: &mut VecZnxBig, res_col: usize, a: &VecZnx, a_col: usize); /// Subtracts `b` to `a` and stores the result on `c`. fn vec_znx_big_sub_small_b( &self, res: &mut VecZnxBig, - col_res: usize, + res_col: usize, a: &VecZnxBig, - col_a: usize, + a_col: usize, b: &VecZnx, - col_b: usize, + b_col: usize, ); /// Subtracts `b` to `a` and stores the result on `b`. - fn vec_znx_big_sub_small_b_inplace(&self, res: &mut VecZnxBig, col_res: usize, a: &VecZnx, col_a: usize); + fn vec_znx_big_sub_small_b_inplace(&self, res: &mut VecZnxBig, res_col: usize, a: &VecZnx, a_col: usize); /// Returns the minimum number of bytes to apply [VecZnxBigOps::vec_znx_big_normalize]. fn vec_znx_big_normalize_tmp_bytes(&self) -> usize; @@ -123,44 +124,44 @@ pub trait VecZnxBigOps { &self, log_base2k: usize, res: &mut VecZnx, - col_res: usize, + res_col: usize, a: &VecZnxBig, - col_a: usize, + a_col: usize, tmp_bytes: &mut [u8], ); /// Applies the automorphism X^i -> X^ik on `a` and stores the result on `b`. - fn vec_znx_big_automorphism(&self, k: i64, res: &mut VecZnxBig, col_res: usize, a: &VecZnxBig, col_a: usize); + fn vec_znx_big_automorphism(&self, k: i64, res: &mut VecZnxBig, res_col: usize, a: &VecZnxBig, a_col: usize); /// Applies the automorphism X^i -> X^ik on `a` and stores the result on `a`. - fn vec_znx_big_automorphism_inplace(&self, k: i64, a: &mut VecZnxBig, col_a: usize); + fn vec_znx_big_automorphism_inplace(&self, k: i64, a: &mut VecZnxBig, a_col: usize); } impl VecZnxBigOps for Module { fn new_vec_znx_big(&self, cols: usize, size: usize) -> VecZnxBig { - VecZnxBig::new(self, cols, size) + VecZnxBig::new(self, 1, cols, size) } - fn new_vec_znx_big_from_bytes(&self, cols: usize, size: usize, bytes: &mut [u8]) -> VecZnxBig { - VecZnxBig::from_bytes(self, cols, size, bytes) + fn new_vec_znx_big_from_bytes(&self, cols: usize, size: usize, bytes: Vec) -> VecZnxBig { + VecZnxBig::from_bytes(self, 1, cols, size, bytes) } fn new_vec_znx_big_from_bytes_borrow(&self, cols: usize, size: usize, tmp_bytes: &mut [u8]) -> VecZnxBig { - VecZnxBig::from_bytes_borrow(self, cols, size, tmp_bytes) + VecZnxBig::from_bytes_borrow(self, 1, cols, size, tmp_bytes) } fn bytes_of_vec_znx_big(&self, cols: usize, size: usize) -> usize { - VecZnxBig::bytes_of(self, cols, size) + VecZnxBig::bytes_of(self, 1, cols, size) } fn vec_znx_big_add( &self, res: &mut VecZnxBig, - col_res: usize, + res_col: usize, a: &VecZnxBig, - col_a: usize, + a_col: usize, b: &VecZnxBig, - col_b: usize, + b_col: usize, ) { #[cfg(debug_assertions)] { @@ -170,36 +171,33 @@ impl VecZnxBigOps for Module { assert_ne!(a.as_ptr(), b.as_ptr()); } unsafe { - vec_znx::vec_znx_add( + vec_znx_big::vec_znx_big_add( self.ptr, - res.at_mut_ptr(col_res, 0), + res.at_mut_ptr(res_col * res.size(), 0) as *mut vec_znx_big_t, res.size() as u64, - res.sl() as u64, - a.at_ptr(col_a, 0), + a.at_ptr(a_col * res.size(), 0) as *const vec_znx_big_t, a.size() as u64, - a.sl() as u64, - b.at_ptr(col_b, 0), + b.at_ptr(b_col * res.size(), 0) as *const vec_znx_big_t, b.size() as u64, - b.sl() as u64, ) } } - fn vec_znx_big_add_inplace(&self, res: &mut VecZnxBig, col_res: usize, a: &VecZnxBig, col_a: usize) { + fn vec_znx_big_add_inplace(&self, res: &mut VecZnxBig, res_col: usize, a: &VecZnxBig, a_col: usize) { unsafe { let res_ptr: *mut VecZnxBig = res as *mut VecZnxBig; - Self::vec_znx_big_add(self, &mut *res_ptr, col_res, a, col_a, &*res_ptr, col_res); + Self::vec_znx_big_add(self, &mut *res_ptr, res_col, a, a_col, &*res_ptr, res_col); } } fn vec_znx_big_sub( &self, res: &mut VecZnxBig, - col_res: usize, + res_col: usize, a: &VecZnxBig, - col_a: usize, + a_col: usize, b: &VecZnxBig, - col_b: usize, + b_col: usize, ) { #[cfg(debug_assertions)] { @@ -209,43 +207,40 @@ impl VecZnxBigOps for Module { assert_ne!(a.as_ptr(), b.as_ptr()); } unsafe { - vec_znx::vec_znx_sub( + vec_znx_big::vec_znx_big_sub( self.ptr, - res.at_mut_ptr(col_res, 0), + res.at_mut_ptr(res_col * res.size(), 0) as *mut vec_znx_big_t, res.size() as u64, - res.sl() as u64, - a.at_ptr(col_a, 0), + a.at_ptr(a_col * res.size(), 0) as *const vec_znx_big_t, a.size() as u64, - a.sl() as u64, - b.at_ptr(col_b, 0), + b.at_ptr(b_col * res.size(), 0) as *const vec_znx_big_t, b.size() as u64, - b.sl() as u64, ) } } - fn vec_znx_big_sub_ab_inplace(&self, res: &mut VecZnxBig, col_res: usize, a: &VecZnxBig, col_a: usize) { + fn vec_znx_big_sub_ab_inplace(&self, res: &mut VecZnxBig, res_col: usize, a: &VecZnxBig, a_col: usize) { unsafe { let res_ptr: *mut VecZnxBig = res as *mut VecZnxBig; - Self::vec_znx_big_sub(self, &mut *res_ptr, col_res, a, col_a, &*res_ptr, col_res); + Self::vec_znx_big_sub(self, &mut *res_ptr, res_col, a, a_col, &*res_ptr, res_col); } } - fn vec_znx_big_sub_ba_inplace(&self, res: &mut VecZnxBig, col_res: usize, a: &VecZnxBig, col_a: usize) { + fn vec_znx_big_sub_ba_inplace(&self, res: &mut VecZnxBig, res_col: usize, a: &VecZnxBig, a_col: usize) { unsafe { let res_ptr: *mut VecZnxBig = res as *mut VecZnxBig; - Self::vec_znx_big_sub(self, &mut *res_ptr, col_res, &*res_ptr, col_res, a, col_a); + Self::vec_znx_big_sub(self, &mut *res_ptr, res_col, &*res_ptr, res_col, a, a_col); } } fn vec_znx_big_sub_small_b( &self, res: &mut VecZnxBig, - col_res: usize, + res_col: usize, a: &VecZnxBig, - col_a: usize, + a_col: usize, b: &VecZnx, - col_b: usize, + b_col: usize, ) { #[cfg(debug_assertions)] { @@ -255,36 +250,34 @@ impl VecZnxBigOps for Module { assert_ne!(a.as_ptr(), b.as_ptr()); } unsafe { - vec_znx::vec_znx_sub( + vec_znx_big::vec_znx_big_sub_small_b( self.ptr, - res.at_mut_ptr(col_res, 0), + res.at_mut_ptr(res_col * res.size(), 0) as *mut vec_znx_big_t, res.size() as u64, - res.sl() as u64, - a.at_ptr(col_a, 0), + a.at_ptr(a_col * a.size(), 0) as *const vec_znx_big_t, a.size() as u64, - a.sl() as u64, - b.at_ptr(col_b, 0), + b.at_ptr(b_col, 0), b.size() as u64, b.sl() as u64, ) } } - fn vec_znx_big_sub_small_b_inplace(&self, res: &mut VecZnxBig, col_res: usize, a: &VecZnx, col_a: usize) { + fn vec_znx_big_sub_small_b_inplace(&self, res: &mut VecZnxBig, res_col: usize, a: &VecZnx, a_col: usize) { unsafe { let res_ptr: *mut VecZnxBig = res as *mut VecZnxBig; - Self::vec_znx_big_sub_small_b(self, &mut *res_ptr, col_res, &*res_ptr, col_res, a, col_a); + Self::vec_znx_big_sub_small_b(self, &mut *res_ptr, res_col, &*res_ptr, res_col, a, a_col); } } fn vec_znx_big_sub_small_a( &self, res: &mut VecZnxBig, - col_res: usize, + res_col: usize, a: &VecZnx, - col_a: usize, + a_col: usize, b: &VecZnxBig, - col_b: usize, + b_col: usize, ) { #[cfg(debug_assertions)] { @@ -294,36 +287,34 @@ impl VecZnxBigOps for Module { assert_ne!(a.as_ptr(), b.as_ptr()); } unsafe { - vec_znx::vec_znx_sub( + vec_znx_big::vec_znx_big_sub_small_a( self.ptr, - res.at_mut_ptr(col_res, 0), + res.at_mut_ptr(res_col * res.size(), 0) as *mut vec_znx_big_t, res.size() as u64, - res.sl() as u64, - a.at_ptr(col_a, 0), + a.at_ptr(a_col, 0), a.size() as u64, a.sl() as u64, - b.at_ptr(col_b, 0), + b.at_ptr(b_col * b.size(), 0) as *const vec_znx_big_t, b.size() as u64, - b.sl() as u64, ) } } - fn vec_znx_big_sub_small_a_inplace(&self, res: &mut VecZnxBig, col_res: usize, a: &VecZnx, col_a: usize) { + fn vec_znx_big_sub_small_a_inplace(&self, res: &mut VecZnxBig, res_col: usize, a: &VecZnx, a_col: usize) { unsafe { let res_ptr: *mut VecZnxBig = res as *mut VecZnxBig; - Self::vec_znx_big_sub_small_a(self, &mut *res_ptr, col_res, a, col_a, &*res_ptr, col_res); + Self::vec_znx_big_sub_small_a(self, &mut *res_ptr, res_col, a, a_col, &*res_ptr, res_col); } } fn vec_znx_big_add_small( &self, res: &mut VecZnxBig, - col_res: usize, - a: &VecZnx, - col_a: usize, - b: &VecZnxBig, - col_b: usize, + res_col: usize, + a: &VecZnxBig, + a_col: usize, + b: &VecZnx, + b_col: usize, ) { #[cfg(debug_assertions)] { @@ -333,25 +324,23 @@ impl VecZnxBigOps for Module { assert_ne!(a.as_ptr(), b.as_ptr()); } unsafe { - vec_znx::vec_znx_add( + vec_znx_big::vec_znx_big_add_small( self.ptr, - res.at_mut_ptr(col_res, 0), + res.at_mut_ptr(res_col * res.size(), 0) as *mut vec_znx_big_t, res.size() as u64, - res.sl() as u64, - a.at_ptr(col_a, 0), + a.at_ptr(a_col * a.size(), 0) as *const vec_znx_big_t, a.size() as u64, - a.sl() as u64, - b.at_ptr(col_b, 0), + b.at_ptr(b_col, 0), b.size() as u64, b.sl() as u64, ) } } - fn vec_znx_big_add_small_inplace(&self, res: &mut VecZnxBig, col_res: usize, a: &VecZnx, a_col: usize) { + fn vec_znx_big_add_small_inplace(&self, res: &mut VecZnxBig, res_col: usize, a: &VecZnx, a_col: usize) { unsafe { let res_ptr: *mut VecZnxBig = res as *mut VecZnxBig; - Self::vec_znx_big_add_small(self, &mut *res_ptr, col_res, a, a_col, &*res_ptr, col_res); + Self::vec_znx_big_add_small(self, &mut *res_ptr, res_col, &*res_ptr, res_col, a, a_col); } } @@ -363,9 +352,9 @@ impl VecZnxBigOps for Module { &self, log_base2k: usize, res: &mut VecZnx, - col_res: usize, + res_col: usize, a: &VecZnxBig, - col_a: usize, + a_col: usize, tmp_bytes: &mut [u8], ) { #[cfg(debug_assertions)] @@ -376,44 +365,41 @@ impl VecZnxBigOps for Module { assert_alignement(tmp_bytes.as_ptr()); } unsafe { - vec_znx::vec_znx_normalize_base2k( + vec_znx_big::vec_znx_big_normalize_base2k( self.ptr, log_base2k as u64, - res.at_mut_ptr(col_res, 0), + res.at_mut_ptr(res_col, 0), res.size() as u64, res.sl() as u64, - a.at_ptr(col_a, 0), + a.at_ptr(a_col * a.size(), 0) as *const vec_znx_big_t, a.size() as u64, - a.sl() as u64, tmp_bytes.as_mut_ptr(), ); } } - fn vec_znx_big_automorphism(&self, k: i64, res: &mut VecZnxBig, col_res: usize, a: &VecZnxBig, col_a: usize) { + fn vec_znx_big_automorphism(&self, k: i64, res: &mut VecZnxBig, res_col: usize, a: &VecZnxBig, a_col: usize) { #[cfg(debug_assertions)] { assert_eq!(a.n(), self.n()); assert_eq!(res.n(), self.n()); } unsafe { - vec_znx::vec_znx_automorphism( + vec_znx_big::vec_znx_big_automorphism( self.ptr, k, - res.at_mut_ptr(col_res, 0), + res.at_mut_ptr(res_col * res.size(), 0) as *mut vec_znx_big_t, res.size() as u64, - res.sl() as u64, - a.at_ptr(col_a, 0), + a.at_ptr(a_col * a.size(), 0) as *const vec_znx_big_t, a.size() as u64, - a.sl() as u64, ) } } - fn vec_znx_big_automorphism_inplace(&self, k: i64, a: &mut VecZnxBig, col_a: usize) { + fn vec_znx_big_automorphism_inplace(&self, k: i64, a: &mut VecZnxBig, a_col: usize) { unsafe { let a_ptr: *mut VecZnxBig = a as *mut VecZnxBig; - Self::vec_znx_big_automorphism(self, k, &mut *a_ptr, col_a, &*a_ptr, col_a); + Self::vec_znx_big_automorphism(self, k, &mut *a_ptr, a_col, &*a_ptr, a_col); } } } diff --git a/base2k/src/vec_znx_dft.rs b/base2k/src/vec_znx_dft.rs index 1b88af5..09ee971 100644 --- a/base2k/src/vec_znx_dft.rs +++ b/base2k/src/vec_znx_dft.rs @@ -1,129 +1,54 @@ -use crate::ffi::vec_znx_big::vec_znx_big_t; use crate::ffi::vec_znx_dft; -use crate::ffi::vec_znx_dft::{bytes_of_vec_znx_dft, vec_znx_dft_t}; -use crate::{Backend, FFT64, Module, VecZnxBig, ZnxBase, ZnxInfos, ZnxLayout, assert_alignement}; -use crate::{VecZnx, alloc_aligned}; +use crate::znx_base::{GetZnxBase, ZnxAlloc, ZnxBase, ZnxInfos, ZnxLayout, ZnxSliceSize}; +use crate::{Backend, FFT64, Module, VecZnxBig}; use std::marker::PhantomData; +const VEC_ZNX_DFT_ROWS: usize = 1; + pub struct VecZnxDft { - pub data: Vec, - pub ptr: *mut u8, - pub n: usize, - pub cols: usize, - pub size: usize, + inner: ZnxBase, pub _marker: PhantomData, } -impl ZnxBase for VecZnxDft { +impl GetZnxBase for VecZnxDft { + fn znx(&self) -> &ZnxBase { + &self.inner + } + + fn znx_mut(&mut self) -> &mut ZnxBase { + &mut self.inner + } +} + +impl ZnxInfos for VecZnxDft {} + +impl ZnxAlloc for VecZnxDft { type Scalar = u8; - fn new(module: &Module, cols: usize, size: usize) -> Self { - #[cfg(debug_assertions)] - { - assert!(cols > 0); - assert!(size > 0); - } - let mut data: Vec = alloc_aligned(Self::bytes_of(module, cols, size)); - let ptr: *mut Self::Scalar = data.as_mut_ptr(); - Self { - data: data, - ptr: ptr, - n: module.n(), - size: size, - cols: cols, + fn from_bytes_borrow(module: &Module, _rows: usize, cols: usize, size: usize, bytes: &mut [u8]) -> Self { + VecZnxDft { + inner: ZnxBase::from_bytes_borrow(module.n(), VEC_ZNX_DFT_ROWS, cols, size, bytes), _marker: PhantomData, } } - fn bytes_of(module: &Module, cols: usize, size: usize) -> usize { - unsafe { bytes_of_vec_znx_dft(module.ptr, size as u64) as usize * cols } - } - - /// Returns a new [VecZnxDft] with the provided data as backing array. - /// User must ensure that data is properly alligned and that - /// the size of data is at least equal to [Module::bytes_of_vec_znx_dft]. - fn from_bytes(module: &Module, cols: usize, size: usize, bytes: &mut [Self::Scalar]) -> Self { - #[cfg(debug_assertions)] - { - assert!(cols > 0); - assert!(size > 0); - assert_eq!(bytes.len(), Self::bytes_of(module, cols, size)); - assert_alignement(bytes.as_ptr()) - } - unsafe { - Self { - data: Vec::from_raw_parts(bytes.as_mut_ptr(), bytes.len(), bytes.len()), - ptr: bytes.as_mut_ptr(), - n: module.n(), - cols: cols, - size: size, - _marker: PhantomData, - } - } - } - - fn from_bytes_borrow(module: &Module, cols: usize, size: usize, bytes: &mut [Self::Scalar]) -> Self { - #[cfg(debug_assertions)] - { - assert!(cols > 0); - assert!(size > 0); - assert_eq!(bytes.len(), Self::bytes_of(module, cols, size)); - assert_alignement(bytes.as_ptr()); - } - Self { - data: Vec::new(), - ptr: bytes.as_mut_ptr(), - n: module.n(), - cols: cols, - size: size, - _marker: PhantomData, - } - } -} - -impl VecZnxDft { - /// Cast a [VecZnxDft] into a [VecZnxBig]. - /// The returned [VecZnxBig] shares the backing array - /// with the original [VecZnxDft]. - pub fn as_vec_znx_big(&mut self) -> VecZnxBig { - VecZnxBig:: { - data: Vec::new(), - ptr: self.ptr, - n: self.n, - cols: self.cols, - size: self.size, - _marker: PhantomData, - } - } -} - -impl ZnxInfos for VecZnxDft { - fn n(&self) -> usize { - self.n - } - - fn rows(&self) -> usize { - 1 - } - - fn cols(&self) -> usize { - self.cols - } - - fn size(&self) -> usize { - self.size + fn bytes_of(module: &Module, _rows: usize, cols: usize, size: usize) -> usize { + debug_assert_eq!( + _rows, VEC_ZNX_DFT_ROWS, + "rows != {} not supported for VecZnxDft", + VEC_ZNX_DFT_ROWS + ); + unsafe { vec_znx_dft::bytes_of_vec_znx_dft(module.ptr, size as u64) as usize * cols } } } impl ZnxLayout for VecZnxDft { type Scalar = f64; +} - fn as_ptr(&self) -> *const Self::Scalar { - self.ptr as *const Self::Scalar - } - - fn as_mut_ptr(&mut self) -> *mut Self::Scalar { - self.ptr as *mut Self::Scalar +impl ZnxSliceSize for VecZnxDft { + fn sl(&self) -> usize { + self.n() } } @@ -133,225 +58,21 @@ impl VecZnxDft { } } -pub trait VecZnxDftOps { - /// Allocates a vector Z[X]/(X^N+1) that stores normalized in the DFT space. - fn new_vec_znx_dft(&self, cols: usize, size: usize) -> VecZnxDft; - - /// Returns a new [VecZnxDft] with the provided bytes array as backing array. - /// - /// Behavior: takes ownership of the backing array. - /// - /// # Arguments - /// - /// * `cols`: the number of cols of the [VecZnxDft]. - /// * `bytes`: a byte array of size at least [Module::bytes_of_vec_znx_dft]. - /// - /// # Panics - /// If `bytes.len()` < [Module::bytes_of_vec_znx_dft]. - fn new_vec_znx_dft_from_bytes(&self, cols: usize, size: usize, bytes: &mut [u8]) -> VecZnxDft; - - /// Returns a new [VecZnxDft] with the provided bytes array as backing array. - /// - /// Behavior: the backing array is only borrowed. - /// - /// # Arguments - /// - /// * `cols`: the number of cols of the [VecZnxDft]. - /// * `bytes`: a byte array of size at least [Module::bytes_of_vec_znx_dft]. - /// - /// # Panics - /// If `bytes.len()` < [Module::bytes_of_vec_znx_dft]. - fn new_vec_znx_dft_from_bytes_borrow(&self, cols: usize, size: usize, bytes: &mut [u8]) -> VecZnxDft; - - /// Returns a new [VecZnxDft] with the provided bytes array as backing array. - /// - /// # Arguments - /// - /// * `cols`: the number of cols of the [VecZnxDft]. - /// * `bytes`: a byte array of size at least [Module::bytes_of_vec_znx_dft]. - /// - /// # Panics - /// If `bytes.len()` < [Module::bytes_of_vec_znx_dft]. - fn bytes_of_vec_znx_dft(&self, cols: usize, size: usize) -> usize; - - /// Returns the minimum number of bytes necessary to allocate - /// a new [VecZnxDft] through [VecZnxDft::from_bytes]. - fn vec_znx_idft_tmp_bytes(&self) -> usize; - - /// b <- IDFT(a), uses a as scratch space. - fn vec_znx_idft_tmp_a(&self, b: &mut VecZnxBig, a: &mut VecZnxDft); - - fn vec_znx_idft(&self, b: &mut VecZnxBig, a: &VecZnxDft, tmp_bytes: &mut [u8]); - - fn vec_znx_dft(&self, b: &mut VecZnxDft, a: &VecZnx); - - fn vec_znx_dft_automorphism(&self, k: i64, b: &mut VecZnxDft, a: &VecZnxDft); - - fn vec_znx_dft_automorphism_inplace(&self, k: i64, a: &mut VecZnxDft, tmp_bytes: &mut [u8]); - - fn vec_znx_dft_automorphism_tmp_bytes(&self) -> usize; -} - -impl VecZnxDftOps for Module { - fn new_vec_znx_dft(&self, cols: usize, size: usize) -> VecZnxDft { - VecZnxDft::::new(&self, cols, size) - } - - fn new_vec_znx_dft_from_bytes(&self, cols: usize, size: usize, tmp_bytes: &mut [u8]) -> VecZnxDft { - VecZnxDft::from_bytes(self, cols, size, tmp_bytes) - } - - fn new_vec_znx_dft_from_bytes_borrow(&self, cols: usize, size: usize, tmp_bytes: &mut [u8]) -> VecZnxDft { - VecZnxDft::from_bytes_borrow(self, cols, size, tmp_bytes) - } - - fn bytes_of_vec_znx_dft(&self, cols: usize, size: usize) -> usize { - VecZnxDft::bytes_of(&self, cols, size) - } - - fn vec_znx_idft_tmp_a(&self, b: &mut VecZnxBig, a: &mut VecZnxDft) { - unsafe { - vec_znx_dft::vec_znx_idft_tmp_a( - self.ptr, - b.ptr as *mut vec_znx_big_t, - b.poly_count() as u64, - a.ptr as *mut vec_znx_dft_t, - a.poly_count() as u64, - ) +impl VecZnxDft { + /// Cast a [VecZnxDft] into a [VecZnxBig]. + /// The returned [VecZnxBig] shares the backing array + /// with the original [VecZnxDft]. + pub fn alias_as_vec_znx_big(&mut self) -> VecZnxBig { + VecZnxBig:: { + inner: ZnxBase { + data: Vec::new(), + ptr: self.ptr(), + n: self.n(), + rows: self.rows(), + cols: self.cols(), + size: self.size(), + }, + _marker: PhantomData, } } - - fn vec_znx_idft_tmp_bytes(&self) -> usize { - unsafe { vec_znx_dft::vec_znx_idft_tmp_bytes(self.ptr) as usize } - } - - /// b <- DFT(a) - /// - /// # Panics - /// If b.cols < a_cols - fn vec_znx_dft(&self, b: &mut VecZnxDft, a: &VecZnx) { - unsafe { - vec_znx_dft::vec_znx_dft( - self.ptr, - b.ptr as *mut vec_znx_dft_t, - b.size() as u64, - a.as_ptr(), - a.size() as u64, - (a.n() * a.cols()) as u64, - ) - } - } - - // b <- IDFT(a), scratch space size obtained with [vec_znx_idft_tmp_bytes]. - fn vec_znx_idft(&self, b: &mut VecZnxBig, a: &VecZnxDft, tmp_bytes: &mut [u8]) { - #[cfg(debug_assertions)] - { - assert!( - tmp_bytes.len() >= Self::vec_znx_idft_tmp_bytes(self), - "invalid tmp_bytes: tmp_bytes.len()={} < self.vec_znx_idft_tmp_bytes()={}", - tmp_bytes.len(), - Self::vec_znx_idft_tmp_bytes(self) - ); - assert_alignement(tmp_bytes.as_ptr()) - } - unsafe { - vec_znx_dft::vec_znx_idft( - self.ptr, - b.ptr as *mut vec_znx_big_t, - b.poly_count() as u64, - a.ptr as *const vec_znx_dft_t, - a.poly_count() as u64, - tmp_bytes.as_mut_ptr(), - ) - } - } - - fn vec_znx_dft_automorphism(&self, k: i64, b: &mut VecZnxDft, a: &VecZnxDft) { - unsafe { - vec_znx_dft::vec_znx_dft_automorphism( - self.ptr, - k, - b.ptr as *mut vec_znx_dft_t, - b.poly_count() as u64, - a.ptr as *const vec_znx_dft_t, - a.poly_count() as u64, - [0u8; 0].as_mut_ptr(), - ); - } - } - - fn vec_znx_dft_automorphism_inplace(&self, k: i64, a: &mut VecZnxDft, tmp_bytes: &mut [u8]) { - #[cfg(debug_assertions)] - { - assert!( - tmp_bytes.len() >= Self::vec_znx_dft_automorphism_tmp_bytes(self), - "invalid tmp_bytes: tmp_bytes.len()={} < self.vec_znx_dft_automorphism_tmp_bytes()={}", - tmp_bytes.len(), - Self::vec_znx_dft_automorphism_tmp_bytes(self) - ); - assert_alignement(tmp_bytes.as_ptr()) - } - println!("{}", a.poly_count()); - unsafe { - vec_znx_dft::vec_znx_dft_automorphism( - self.ptr, - k, - a.ptr as *mut vec_znx_dft_t, - a.poly_count() as u64, - a.ptr as *const vec_znx_dft_t, - a.poly_count() as u64, - tmp_bytes.as_mut_ptr(), - ); - } - } - - fn vec_znx_dft_automorphism_tmp_bytes(&self) -> usize { - unsafe { vec_znx_dft::vec_znx_dft_automorphism_tmp_bytes(self.ptr) as usize } - } -} - -#[cfg(test)] -mod tests { - use crate::{FFT64, Module, Sampling, VecZnx, VecZnxDft, VecZnxDftOps, VecZnxOps, ZnxLayout, alloc_aligned}; - use itertools::izip; - use sampling::source::Source; - - #[test] - fn test_automorphism_dft() { - let n: usize = 8; - let module: Module = Module::::new(n); - - let size: usize = 2; - let log_base2k: usize = 17; - let mut a: VecZnx = module.new_vec_znx(1, size); - let mut a_dft: VecZnxDft = module.new_vec_znx_dft(1, size); - let mut b_dft: VecZnxDft = module.new_vec_znx_dft(1, size); - - let mut source: Source = Source::new([0u8; 32]); - module.fill_uniform(log_base2k, &mut a, 0, size, &mut source); - - let mut tmp_bytes: Vec = alloc_aligned(module.vec_znx_dft_automorphism_tmp_bytes()); - - let p: i64 = -5; - - // a_dft <- DFT(a) - module.vec_znx_dft(&mut a_dft, &a); - - // a_dft <- AUTO(a_dft) - module.vec_znx_dft_automorphism_inplace(p, &mut a_dft, &mut tmp_bytes); - - // a <- AUTO(a) - module.vec_znx_automorphism_inplace(p, &mut a, 0); - - // b_dft <- DFT(AUTO(a)) - module.vec_znx_dft(&mut b_dft, &a); - - let a_f64: &[f64] = a_dft.raw(); - let b_f64: &[f64] = b_dft.raw(); - izip!(a_f64.iter(), b_f64.iter()).for_each(|(ai, bi)| { - assert!((ai - bi).abs() <= 1e-9, "{:+e} > 1e-9", (ai - bi).abs()); - }); - - module.free() - } } diff --git a/base2k/src/vec_znx_dft_ops.rs b/base2k/src/vec_znx_dft_ops.rs new file mode 100644 index 0000000..57b3777 --- /dev/null +++ b/base2k/src/vec_znx_dft_ops.rs @@ -0,0 +1,140 @@ +use crate::ffi::vec_znx_big; +use crate::ffi::vec_znx_dft; +use crate::znx_base::ZnxAlloc; +use crate::znx_base::ZnxInfos; +use crate::znx_base::ZnxLayout; +use crate::znx_base::ZnxSliceSize; +use crate::{Backend, FFT64, Module, VecZnx, VecZnxBig, VecZnxDft, assert_alignement}; + +pub trait VecZnxDftOps { + /// Allocates a vector Z[X]/(X^N+1) that stores normalized in the DFT space. + fn new_vec_znx_dft(&self, cols: usize, size: usize) -> VecZnxDft; + + /// Returns a new [VecZnxDft] with the provided bytes array as backing array. + /// + /// Behavior: takes ownership of the backing array. + /// + /// # Arguments + /// + /// * `cols`: the number of cols of the [VecZnxDft]. + /// * `bytes`: a byte array of size at least [Module::bytes_of_vec_znx_dft]. + /// + /// # Panics + /// If `bytes.len()` < [Module::bytes_of_vec_znx_dft]. + fn new_vec_znx_dft_from_bytes(&self, cols: usize, size: usize, bytes: Vec) -> VecZnxDft; + + /// Returns a new [VecZnxDft] with the provided bytes array as backing array. + /// + /// Behavior: the backing array is only borrowed. + /// + /// # Arguments + /// + /// * `cols`: the number of cols of the [VecZnxDft]. + /// * `bytes`: a byte array of size at least [Module::bytes_of_vec_znx_dft]. + /// + /// # Panics + /// If `bytes.len()` < [Module::bytes_of_vec_znx_dft]. + fn new_vec_znx_dft_from_bytes_borrow(&self, cols: usize, size: usize, bytes: &mut [u8]) -> VecZnxDft; + + /// Returns a new [VecZnxDft] with the provided bytes array as backing array. + /// + /// # Arguments + /// + /// * `cols`: the number of cols of the [VecZnxDft]. + /// * `bytes`: a byte array of size at least [Module::bytes_of_vec_znx_dft]. + /// + /// # Panics + /// If `bytes.len()` < [Module::bytes_of_vec_znx_dft]. + fn bytes_of_vec_znx_dft(&self, cols: usize, size: usize) -> usize; + + /// Returns the minimum number of bytes necessary to allocate + /// a new [VecZnxDft] through [VecZnxDft::from_bytes]. + fn vec_znx_idft_tmp_bytes(&self) -> usize; + + /// b <- IDFT(a), uses a as scratch space. + fn vec_znx_idft_tmp_a(&self, res: &mut VecZnxBig, res_col: usize, a: &mut VecZnxDft, a_cols: usize); + + fn vec_znx_idft(&self, res: &mut VecZnxBig, res_col: usize, a: &VecZnxDft, a_col: usize, tmp_bytes: &mut [u8]); + + fn vec_znx_dft(&self, res: &mut VecZnxDft, res_col: usize, a: &VecZnx, a_col: usize); +} + +impl VecZnxDftOps for Module { + fn new_vec_znx_dft(&self, cols: usize, size: usize) -> VecZnxDft { + VecZnxDft::::new(&self, 1, cols, size) + } + + fn new_vec_znx_dft_from_bytes(&self, cols: usize, size: usize, bytes: Vec) -> VecZnxDft { + VecZnxDft::from_bytes(self, 1, cols, size, bytes) + } + + fn new_vec_znx_dft_from_bytes_borrow(&self, cols: usize, size: usize, bytes: &mut [u8]) -> VecZnxDft { + VecZnxDft::from_bytes_borrow(self, 1, cols, size, bytes) + } + + fn bytes_of_vec_znx_dft(&self, cols: usize, size: usize) -> usize { + VecZnxDft::bytes_of(&self, 1, cols, size) + } + + fn vec_znx_idft_tmp_a(&self, res: &mut VecZnxBig, res_col: usize, a: &mut VecZnxDft, a_col: usize) { + #[cfg(debug_assertions)] + { + assert_eq!(res.poly_count(), a.poly_count()); + } + + unsafe { + vec_znx_dft::vec_znx_idft_tmp_a( + self.ptr, + res.at_mut_ptr(res_col * res.size(), 0) as *mut vec_znx_big::vec_znx_big_t, + res.size() as u64, + a.at_ptr(a_col * a.size(), 0) as *mut vec_znx_dft::vec_znx_dft_t, + a.size() as u64, + ) + } + } + + fn vec_znx_idft_tmp_bytes(&self) -> usize { + unsafe { vec_znx_dft::vec_znx_idft_tmp_bytes(self.ptr) as usize } + } + + /// b <- DFT(a) + /// + /// # Panics + /// If b.cols < a_cols + fn vec_znx_dft(&self, res: &mut VecZnxDft, res_col: usize, a: &VecZnx, a_col: usize) { + unsafe { + vec_znx_dft::vec_znx_dft( + self.ptr, + res.at_mut_ptr(res_col * res.size(), 0) as *mut vec_znx_dft::vec_znx_dft_t, + res.size() as u64, + a.at_ptr(a_col, 0), + a.size() as u64, + a.sl() as u64, + ) + } + } + + // b <- IDFT(a), scratch space size obtained with [vec_znx_idft_tmp_bytes]. + fn vec_znx_idft(&self, res: &mut VecZnxBig, res_col: usize, a: &VecZnxDft, a_col: usize, tmp_bytes: &mut [u8]) { + #[cfg(debug_assertions)] + { + assert!( + tmp_bytes.len() >= Self::vec_znx_idft_tmp_bytes(self), + "invalid tmp_bytes: tmp_bytes.len()={} < self.vec_znx_idft_tmp_bytes()={}", + tmp_bytes.len(), + Self::vec_znx_idft_tmp_bytes(self) + ); + assert_alignement(tmp_bytes.as_ptr()) + } + unsafe { + vec_znx_dft::vec_znx_idft( + self.ptr, + res.at_mut_ptr(res_col * res.size(), 0) as *mut vec_znx_big::vec_znx_big_t, + res.size() as u64, + a.at_ptr(a_col * res.size(), 0) as *const vec_znx_dft::vec_znx_dft_t, + a.size() as u64, + tmp_bytes.as_mut_ptr(), + ) + } + } +} diff --git a/base2k/src/vec_znx_ops.rs b/base2k/src/vec_znx_ops.rs index 9f2d43a..7ee1529 100644 --- a/base2k/src/vec_znx_ops.rs +++ b/base2k/src/vec_znx_ops.rs @@ -1,5 +1,6 @@ use crate::ffi::vec_znx; -use crate::{Backend, Module, VecZnx, ZnxBase, ZnxInfos, ZnxLayout, assert_alignement, switch_degree}; +use crate::znx_base::{ZnxAlloc, ZnxInfos, ZnxLayout, ZnxSliceSize, switch_degree}; +use crate::{Backend, Module, VEC_ZNX_ROWS, VecZnx, assert_alignement}; pub trait VecZnxOps { /// Allocates a new [VecZnx]. /// @@ -19,7 +20,7 @@ pub trait VecZnxOps { /// /// # Panic /// Requires the slice of bytes to be equal to [VecZnxOps::bytes_of_vec_znx]. - fn new_vec_znx_from_bytes(&self, cols: usize, size: usize, bytes: &mut [u8]) -> VecZnx; + fn new_vec_znx_from_bytes(&self, cols: usize, size: usize, bytes: Vec) -> VecZnx; /// Instantiates a new [VecZnx] from a slice of bytes. /// The returned [VecZnx] does take ownership of the slice of bytes. @@ -107,19 +108,19 @@ pub trait VecZnxOps { impl VecZnxOps for Module { fn new_vec_znx(&self, cols: usize, size: usize) -> VecZnx { - VecZnx::new(self, cols, size) + VecZnx::new(self, VEC_ZNX_ROWS, cols, size) } fn bytes_of_vec_znx(&self, cols: usize, size: usize) -> usize { - VecZnx::bytes_of(self, cols, size) + VecZnx::bytes_of(self, VEC_ZNX_ROWS, cols, size) } - fn new_vec_znx_from_bytes(&self, cols: usize, size: usize, bytes: &mut [u8]) -> VecZnx { - VecZnx::from_bytes(self, cols, size, bytes) + fn new_vec_znx_from_bytes(&self, cols: usize, size: usize, bytes: Vec) -> VecZnx { + VecZnx::from_bytes(self, VEC_ZNX_ROWS, cols, size, bytes) } fn new_vec_znx_from_bytes_borrow(&self, cols: usize, size: usize, tmp_bytes: &mut [u8]) -> VecZnx { - VecZnx::from_bytes_borrow(self, cols, size, tmp_bytes) + VecZnx::from_bytes_borrow(self, VEC_ZNX_ROWS, cols, size, tmp_bytes) } fn vec_znx_normalize_tmp_bytes(&self) -> usize { diff --git a/base2k/src/commons.rs b/base2k/src/znx_base.rs similarity index 67% rename from base2k/src/commons.rs rename to base2k/src/znx_base.rs index d5f60ee..64ad85f 100644 --- a/base2k/src/commons.rs +++ b/base2k/src/znx_base.rs @@ -1,10 +1,37 @@ -use crate::{Backend, Module, assert_alignement, cast_mut}; +use crate::{Backend, Module, alloc_aligned, assert_alignement, cast_mut}; use itertools::izip; use std::cmp::min; -pub trait ZnxInfos { +pub struct ZnxBase { + /// The ring degree + pub n: usize, + + /// The number of rows (in the third dimension) + pub rows: usize, + + /// The number of polynomials + pub cols: usize, + + /// The number of size per polynomial (a.k.a small polynomials). + pub size: usize, + + /// Polynomial coefficients, as a contiguous array. Each col is equally spaced by n. + pub data: Vec, + + /// Pointer to data (data can be enpty if [VecZnx] borrows space instead of owning it). + pub ptr: *mut u8, +} + +pub trait GetZnxBase { + fn znx(&self) -> &ZnxBase; + fn znx_mut(&mut self) -> &mut ZnxBase; +} + +pub trait ZnxInfos: GetZnxBase { /// Returns the ring degree of the polynomials. - fn n(&self) -> usize; + fn n(&self) -> usize { + self.znx().n + } /// Returns the base two logarithm of the ring dimension of the polynomials. fn log_n(&self) -> usize { @@ -12,41 +39,104 @@ pub trait ZnxInfos { } /// Returns the number of rows. - fn rows(&self) -> usize; - + fn rows(&self) -> usize { + self.znx().rows + } /// Returns the number of polynomials in each row. - fn cols(&self) -> usize; + fn cols(&self) -> usize { + self.znx().cols + } /// Returns the number of size per polynomial. - fn size(&self) -> usize; + fn size(&self) -> usize { + self.znx().size + } + + fn data(&self) -> &[u8] { + &self.znx().data + } + + fn ptr(&self) -> *mut u8 { + self.znx().ptr + } /// Returns the total number of small polynomials. fn poly_count(&self) -> usize { self.rows() * self.cols() * self.size() } +} +pub trait ZnxSliceSize { /// Returns the slice size, which is the offset between /// two size of the same column. - fn sl(&self) -> usize { - self.n() * self.cols() + fn sl(&self) -> usize; +} + +impl ZnxBase { + pub fn from_bytes(n: usize, rows: usize, cols: usize, size: usize, mut bytes: Vec) -> Self { + let mut res: Self = Self::from_bytes_borrow(n, rows, cols, size, &mut bytes); + res.data = bytes; + res + } + + pub fn from_bytes_borrow(n: usize, rows: usize, cols: usize, size: usize, bytes: &mut [u8]) -> Self { + #[cfg(debug_assertions)] + { + assert_eq!(n & (n - 1), 0, "n must be a power of two"); + assert!(n > 0, "n must be greater than 0"); + assert!(rows > 0, "rows must be greater than 0"); + assert!(cols > 0, "cols must be greater than 0"); + assert!(size > 0, "size must be greater than 0"); + } + Self { + n: n, + rows: rows, + cols: cols, + size: size, + data: Vec::new(), + ptr: bytes.as_mut_ptr(), + } } } -pub trait ZnxBase { +pub trait ZnxAlloc +where + Self: Sized + ZnxInfos, +{ type Scalar; - fn new(module: &Module, cols: usize, size: usize) -> Self; - fn from_bytes(module: &Module, cols: usize, size: usize, bytes: &mut [u8]) -> Self; - fn from_bytes_borrow(module: &Module, cols: usize, size: usize, bytes: &mut [u8]) -> Self; - fn bytes_of(module: &Module, cols: usize, size: usize) -> usize; + fn new(module: &Module, rows: usize, cols: usize, size: usize) -> Self { + let bytes: Vec = alloc_aligned::(Self::bytes_of(module, rows, cols, size)); + Self::from_bytes(module, rows, cols, size, bytes) + } + + fn from_bytes(module: &Module, rows: usize, cols: usize, size: usize, mut bytes: Vec) -> Self { + let mut res: Self = Self::from_bytes_borrow(module, rows, cols, size, &mut bytes); + res.znx_mut().data = bytes; + res + } + + fn from_bytes_borrow(module: &Module, rows: usize, cols: usize, size: usize, bytes: &mut [u8]) -> Self; + + fn bytes_of(module: &Module, rows: usize, cols: usize, size: usize) -> usize; } + pub trait ZnxLayout: ZnxInfos { type Scalar; + /// Returns true if the receiver is only borrowing the data. + fn borrowing(&self) -> bool { + self.znx().data.len() == 0 + } + /// Returns a non-mutable pointer to the underlying coefficients array. - fn as_ptr(&self) -> *const Self::Scalar; + fn as_ptr(&self) -> *const Self::Scalar { + self.znx().ptr as *const Self::Scalar + } /// Returns a mutable pointer to the underlying coefficients array. - fn as_mut_ptr(&mut self) -> *mut Self::Scalar; + fn as_mut_ptr(&mut self) -> *mut Self::Scalar { + self.znx_mut().ptr as *mut Self::Scalar + } /// Returns a non-mutable reference to the entire underlying coefficient array. fn raw(&self) -> &[Self::Scalar] { diff --git a/rlwe/Cargo.toml b/rlwe/Cargo.toml index a8b8207..0822281 100644 --- a/rlwe/Cargo.toml +++ b/rlwe/Cargo.toml @@ -1,5 +1,3 @@ -cargo-features = ["edition2024"] - [package] name = "rlwe" version = "0.1.0" diff --git a/rlwe/src/automorphism.rs b/rlwe/src/automorphism.rs index d76e356..95a935f 100644 --- a/rlwe/src/automorphism.rs +++ b/rlwe/src/automorphism.rs @@ -20,7 +20,7 @@ pub struct AutomorphismKey { } pub fn automorphis_key_new_tmp_bytes(module: &Module, log_base2k: usize, rows: usize, log_q: usize) -> usize { - module.bytes_of_scalar() + module.bytes_of_svp_ppol() + encrypt_grlwe_sk_tmp_bytes(module, log_base2k, rows, log_q) + module.bytes_of_scalar() + module.bytes_of_scalar_znx_dft() + encrypt_grlwe_sk_tmp_bytes(module, log_base2k, rows, log_q) } impl Parameters { @@ -103,10 +103,10 @@ impl AutomorphismKey { tmp_bytes: &mut [u8], ) -> Vec { let (sk_auto_bytes, tmp_bytes) = tmp_bytes.split_at_mut(module.bytes_of_scalar()); - let (sk_out_bytes, tmp_bytes) = tmp_bytes.split_at_mut(module.bytes_of_svp_ppol()); + let (sk_out_bytes, tmp_bytes) = tmp_bytes.split_at_mut(module.bytes_of_scalar_znx_dft()); let sk_auto: Scalar = module.new_scalar_from_bytes_borrow(sk_auto_bytes); - let mut sk_out: ScalarZnxDft = module.new_svp_ppol_from_bytes_borrow(sk_out_bytes); + let mut sk_out: ScalarZnxDft = module.new_scalar_znx_dft_from_bytes_borrow(sk_out_bytes); let mut keys: Vec = Vec::new(); @@ -116,7 +116,7 @@ impl AutomorphismKey { let p_inv: i64 = module.galois_element_inv(*pi); module.vec_znx_automorphism(p_inv, &mut sk_auto.as_vec_znx(), &sk.0.as_vec_znx()); - module.svp_prepare(&mut sk_out, &sk_auto); + module.scalar_znx_dft_prepare(&mut sk_out, &sk_auto); encrypt_grlwe_sk( module, &mut value, &sk.0, &sk_out, source_xa, source_xe, sigma, tmp_bytes, ); diff --git a/rlwe/src/keys.rs b/rlwe/src/keys.rs index 6017159..511f755 100644 --- a/rlwe/src/keys.rs +++ b/rlwe/src/keys.rs @@ -20,7 +20,7 @@ impl SecretKey { } pub fn prepare(&self, module: &Module, sk_ppol: &mut ScalarZnxDft) { - module.svp_prepare(sk_ppol, &self.0) + module.scalar_znx_dft_prepare(sk_ppol, &self.0) } }