diff --git a/.vscode/settings.json b/.vscode/settings.json index bf18bf9..d32c5df 100644 --- a/.vscode/settings.json +++ b/.vscode/settings.json @@ -57,6 +57,15 @@ "xloctime": "cpp", "xmemory": "cpp", "xtr1common": "cpp", - "vec_znx_arithmetic_private.h": "c" + "vec_znx_arithmetic_private.h": "c", + "reim4_arithmetic.h": "c", + "array": "c", + "string_view": "c" + }, + "github.copilot.enable": { + "*": false, + "plaintext": false, + "markdown": false, + "scminput": false } } \ No newline at end of file diff --git a/base2k/examples/rlwe_encrypt.rs b/base2k/examples/rlwe_encrypt.rs index 592112f..925c5c4 100644 --- a/base2k/examples/rlwe_encrypt.rs +++ b/base2k/examples/rlwe_encrypt.rs @@ -1,6 +1,6 @@ use base2k::{ alloc_aligned, Encoding, Infos, Module, Sampling, Scalar, SvpPPol, SvpPPolOps, VecZnx, - VecZnxBig, VecZnxBigOps, VecZnxDft, VecZnxDftOps, VecZnxOps, MODULETYPE, + VecZnxBig, VecZnxBigOps, VecZnxDft, VecZnxDftOps, VecZnxOps, BACKEND, }; use itertools::izip; use sampling::source::Source; @@ -11,7 +11,7 @@ fn main() { let cols: usize = 3; let msg_cols: usize = 2; let log_scale: usize = msg_cols * log_base2k - 5; - let module: Module = Module::new(n, MODULETYPE::FFT64); + let module: Module = Module::new(n, BACKEND::FFT64); let mut carry: Vec = alloc_aligned(module.vec_znx_big_normalize_tmp_bytes()); diff --git a/base2k/examples/vector_matrix_product.rs b/base2k/examples/vector_matrix_product.rs index cb2ba58..1280cc6 100644 --- a/base2k/examples/vector_matrix_product.rs +++ b/base2k/examples/vector_matrix_product.rs @@ -1,13 +1,13 @@ use base2k::{ alloc_aligned, Encoding, Infos, Module, VecZnx, VecZnxBig, VecZnxBigOps, VecZnxDft, - VecZnxDftOps, VecZnxOps, VecZnxVec, VmpPMat, VmpPMatOps, MODULETYPE, + VecZnxDftOps, VecZnxOps, VecZnxVec, VmpPMat, VmpPMatOps, BACKEND, }; fn main() { let log_n: i32 = 5; let n: usize = 1 << log_n; - let module: Module = Module::new(n, MODULETYPE::FFT64); + let module: Module = Module::new(n, BACKEND::FFT64); let log_base2k: usize = 15; let cols: usize = 5; let log_k: usize = log_base2k * cols - 5; diff --git a/base2k/spqlios-arithmetic b/base2k/spqlios-arithmetic index 5461131..515d616 160000 --- a/base2k/spqlios-arithmetic +++ b/base2k/spqlios-arithmetic @@ -1 +1 @@ -Subproject commit 546113166e0e204cdfcd7a78ed96b6df7c457e40 +Subproject commit 515d616a8ba858b7e858a63dc0fa768eb70ebb99 diff --git a/base2k/src/bindgen.rs b/base2k/src/bindgen.rs deleted file mode 100644 index 39e4ca9..0000000 --- a/base2k/src/bindgen.rs +++ /dev/null @@ -1,36 +0,0 @@ -/* -[build-dependencies] -bindgen ="0.71.1" - -//use bindgen; -//use std::env; -//use std::fs; -//use std::path::PathBuf; -//use std::time::SystemTime; - -// Path to the C header file -let header_paths: [&str; 2] = [ - "spqlios-arithmetic/spqlios/coeffs/coeffs_arithmetic.h", - "spqlios-arithmetic/spqlios/arithmetic/vec_znx_arithmetic.h", -]; - -let out_path: PathBuf = PathBuf::from(env::var("OUT_DIR").unwrap()); -let bindings_file = out_path.join("bindings.rs"); - -let mut builder: bindgen::Builder = bindgen::Builder::default(); -for header in header_paths { - builder = builder.header(header); -} - -let bindings = builder - .generate_comments(false) // Optional: includes comments in bindings - .generate_inline_functions(true) // Optional: includes inline functions - .generate() - .expect("Unable to generate bindings"); - -// Write the bindings to the OUT_DIR -bindings - .write_to_file(&bindings_file) - .expect("Couldn't write bindings!"); - -*/ \ No newline at end of file diff --git a/base2k/src/ffi/reim.rs b/base2k/src/ffi/reim.rs index 7993ee0..ebd9673 100644 --- a/base2k/src/ffi/reim.rs +++ b/base2k/src/ffi/reim.rs @@ -59,40 +59,40 @@ pub struct reim_to_znx64_precomp { } pub type REIM_TO_ZNX64_PRECOMP = reim_to_znx64_precomp; unsafe extern "C" { - pub fn new_reim_fft_precomp(m: u32, num_buffers: u32) -> *mut REIM_FFT_PRECOMP; + pub unsafe fn new_reim_fft_precomp(m: u32, num_buffers: u32) -> *mut REIM_FFT_PRECOMP; } unsafe extern "C" { - pub fn reim_fft_precomp_get_buffer( + pub unsafe fn reim_fft_precomp_get_buffer( tables: *const REIM_FFT_PRECOMP, buffer_index: u32, ) -> *mut f64; } unsafe extern "C" { - pub fn new_reim_fft_buffer(m: u32) -> *mut f64; + pub unsafe fn new_reim_fft_buffer(m: u32) -> *mut f64; } unsafe extern "C" { - pub fn delete_reim_fft_buffer(buffer: *mut f64); + pub unsafe fn delete_reim_fft_buffer(buffer: *mut f64); } unsafe extern "C" { - pub fn reim_fft(tables: *const REIM_FFT_PRECOMP, data: *mut f64); + pub unsafe fn reim_fft(tables: *const REIM_FFT_PRECOMP, data: *mut f64); } unsafe extern "C" { - pub fn new_reim_ifft_precomp(m: u32, num_buffers: u32) -> *mut REIM_IFFT_PRECOMP; + pub unsafe fn new_reim_ifft_precomp(m: u32, num_buffers: u32) -> *mut REIM_IFFT_PRECOMP; } unsafe extern "C" { - pub fn reim_ifft_precomp_get_buffer( + pub unsafe fn reim_ifft_precomp_get_buffer( tables: *const REIM_IFFT_PRECOMP, buffer_index: u32, ) -> *mut f64; } unsafe extern "C" { - pub fn reim_ifft(tables: *const REIM_IFFT_PRECOMP, data: *mut f64); + pub unsafe fn reim_ifft(tables: *const REIM_IFFT_PRECOMP, data: *mut f64); } unsafe extern "C" { - pub fn new_reim_fftvec_mul_precomp(m: u32) -> *mut REIM_FFTVEC_MUL_PRECOMP; + pub unsafe fn new_reim_fftvec_mul_precomp(m: u32) -> *mut REIM_FFTVEC_MUL_PRECOMP; } unsafe extern "C" { - pub fn reim_fftvec_mul( + pub unsafe fn reim_fftvec_mul( tables: *const REIM_FFTVEC_MUL_PRECOMP, r: *mut f64, a: *const f64, @@ -100,10 +100,10 @@ unsafe extern "C" { ); } unsafe extern "C" { - pub fn new_reim_fftvec_addmul_precomp(m: u32) -> *mut REIM_FFTVEC_ADDMUL_PRECOMP; + pub unsafe fn new_reim_fftvec_addmul_precomp(m: u32) -> *mut REIM_FFTVEC_ADDMUL_PRECOMP; } unsafe extern "C" { - pub fn reim_fftvec_addmul( + pub unsafe fn reim_fftvec_addmul( tables: *const REIM_FFTVEC_ADDMUL_PRECOMP, r: *mut f64, a: *const f64, @@ -111,27 +111,30 @@ unsafe extern "C" { ); } unsafe extern "C" { - pub fn new_reim_from_znx32_precomp(m: u32, log2bound: u32) -> *mut REIM_FROM_ZNX32_PRECOMP; + pub unsafe fn new_reim_from_znx32_precomp( + m: u32, + log2bound: u32, + ) -> *mut REIM_FROM_ZNX32_PRECOMP; } unsafe extern "C" { - pub fn reim_from_znx32( + pub unsafe fn reim_from_znx32( tables: *const REIM_FROM_ZNX32_PRECOMP, r: *mut ::std::os::raw::c_void, a: *const i32, ); } unsafe extern "C" { - pub fn reim_from_znx64( + pub unsafe fn reim_from_znx64( tables: *const REIM_FROM_ZNX64_PRECOMP, r: *mut ::std::os::raw::c_void, a: *const i64, ); } unsafe extern "C" { - pub fn new_reim_from_znx64_precomp(m: u32, maxbnd: u32) -> *mut REIM_FROM_ZNX64_PRECOMP; + pub unsafe fn new_reim_from_znx64_precomp(m: u32, maxbnd: u32) -> *mut REIM_FROM_ZNX64_PRECOMP; } unsafe extern "C" { - pub fn reim_from_znx64_simple( + pub unsafe fn reim_from_znx64_simple( m: u32, log2bound: u32, r: *mut ::std::os::raw::c_void, @@ -139,58 +142,64 @@ unsafe extern "C" { ); } unsafe extern "C" { - pub fn new_reim_from_tnx32_precomp(m: u32) -> *mut REIM_FROM_TNX32_PRECOMP; + pub unsafe fn new_reim_from_tnx32_precomp(m: u32) -> *mut REIM_FROM_TNX32_PRECOMP; } unsafe extern "C" { - pub fn reim_from_tnx32( + pub unsafe fn reim_from_tnx32( tables: *const REIM_FROM_TNX32_PRECOMP, r: *mut ::std::os::raw::c_void, a: *const i32, ); } unsafe extern "C" { - pub fn new_reim_to_tnx32_precomp( + pub unsafe fn new_reim_to_tnx32_precomp( m: u32, divisor: f64, log2overhead: u32, ) -> *mut REIM_TO_TNX32_PRECOMP; } unsafe extern "C" { - pub fn reim_to_tnx32( + pub unsafe fn reim_to_tnx32( tables: *const REIM_TO_TNX32_PRECOMP, r: *mut i32, a: *const ::std::os::raw::c_void, ); } unsafe extern "C" { - pub fn new_reim_to_tnx_precomp( + pub unsafe fn new_reim_to_tnx_precomp( m: u32, divisor: f64, log2overhead: u32, ) -> *mut REIM_TO_TNX_PRECOMP; } unsafe extern "C" { - pub fn reim_to_tnx(tables: *const REIM_TO_TNX_PRECOMP, r: *mut f64, a: *const f64); + pub unsafe fn reim_to_tnx(tables: *const REIM_TO_TNX_PRECOMP, r: *mut f64, a: *const f64); } unsafe extern "C" { - pub fn reim_to_tnx_simple(m: u32, divisor: f64, log2overhead: u32, r: *mut f64, a: *const f64); + pub unsafe fn reim_to_tnx_simple( + m: u32, + divisor: f64, + log2overhead: u32, + r: *mut f64, + a: *const f64, + ); } unsafe extern "C" { - pub fn new_reim_to_znx64_precomp( + pub unsafe fn new_reim_to_znx64_precomp( m: u32, divisor: f64, log2bound: u32, ) -> *mut REIM_TO_ZNX64_PRECOMP; } unsafe extern "C" { - pub fn reim_to_znx64( + pub unsafe fn reim_to_znx64( precomp: *const REIM_TO_ZNX64_PRECOMP, r: *mut i64, a: *const ::std::os::raw::c_void, ); } unsafe extern "C" { - pub fn reim_to_znx64_simple( + pub unsafe fn reim_to_znx64_simple( m: u32, divisor: f64, log2bound: u32, @@ -199,13 +208,13 @@ unsafe extern "C" { ); } unsafe extern "C" { - pub fn reim_fft_simple(m: u32, data: *mut ::std::os::raw::c_void); + pub unsafe fn reim_fft_simple(m: u32, data: *mut ::std::os::raw::c_void); } unsafe extern "C" { - pub fn reim_ifft_simple(m: u32, data: *mut ::std::os::raw::c_void); + pub unsafe fn reim_ifft_simple(m: u32, data: *mut ::std::os::raw::c_void); } unsafe extern "C" { - pub fn reim_fftvec_mul_simple( + pub unsafe fn reim_fftvec_mul_simple( m: u32, r: *mut ::std::os::raw::c_void, a: *const ::std::os::raw::c_void, @@ -213,7 +222,7 @@ unsafe extern "C" { ); } unsafe extern "C" { - pub fn reim_fftvec_addmul_simple( + pub unsafe fn reim_fftvec_addmul_simple( m: u32, r: *mut ::std::os::raw::c_void, a: *const ::std::os::raw::c_void, @@ -221,7 +230,7 @@ unsafe extern "C" { ); } unsafe extern "C" { - pub fn reim_from_znx32_simple( + pub unsafe fn reim_from_znx32_simple( m: u32, log2bound: u32, r: *mut ::std::os::raw::c_void, @@ -229,10 +238,10 @@ unsafe extern "C" { ); } unsafe extern "C" { - pub fn reim_from_tnx32_simple(m: u32, r: *mut ::std::os::raw::c_void, x: *const i32); + pub unsafe fn reim_from_tnx32_simple(m: u32, r: *mut ::std::os::raw::c_void, x: *const i32); } unsafe extern "C" { - pub fn reim_to_tnx32_simple( + pub unsafe fn reim_to_tnx32_simple( m: u32, divisor: f64, log2overhead: u32, diff --git a/base2k/src/ffi/vec_znx_big.rs b/base2k/src/ffi/vec_znx_big.rs index f2da750..e1222c3 100644 --- a/base2k/src/ffi/vec_znx_big.rs +++ b/base2k/src/ffi/vec_znx_big.rs @@ -2,10 +2,10 @@ use crate::ffi::module::MODULE; #[repr(C)] #[derive(Debug, Copy, Clone)] -pub struct vec_znx_bigcoeff_t { +pub struct vec_znx_big_t { _unused: [u8; 0], } -pub type VEC_ZNX_BIG = vec_znx_bigcoeff_t; +pub type VEC_ZNX_BIG = vec_znx_big_t; unsafe extern "C" { pub fn bytes_of_vec_znx_big(module: *const MODULE, size: u64) -> u64; diff --git a/base2k/src/ffi/vmp.rs b/base2k/src/ffi/vmp.rs index a0e6a92..9035667 100644 --- a/base2k/src/ffi/vmp.rs +++ b/base2k/src/ffi/vmp.rs @@ -1,4 +1,5 @@ use crate::ffi::module::MODULE; +use crate::ffi::vec_znx_big::VEC_ZNX_BIG; use crate::ffi::vec_znx_dft::VEC_ZNX_DFT; #[repr(C)] @@ -103,6 +104,39 @@ unsafe extern "C" { ); } +unsafe extern "C" { + pub unsafe fn vmp_prepare_row_dft( + module: *const MODULE, + pmat: *mut VMP_PMAT, + row: *const VEC_ZNX_DFT, + row_i: u64, + nrows: u64, + ncols: u64, + ); +} + +unsafe extern "C" { + pub unsafe fn vmp_extract_row_dft( + module: *const MODULE, + res: *mut VEC_ZNX_DFT, + pmat: *const VMP_PMAT, + row_i: u64, + nrows: u64, + ncols: u64, + ); +} + +unsafe extern "C" { + pub unsafe fn vmp_extract_row( + module: *const MODULE, + res: *mut VEC_ZNX_BIG, + pmat: *const VMP_PMAT, + row_i: u64, + nrows: u64, + ncols: u64, + ); +} + unsafe extern "C" { pub unsafe fn vmp_prepare_tmp_bytes(module: *const MODULE, nrows: u64, ncols: u64) -> u64; } diff --git a/base2k/src/module.rs b/base2k/src/module.rs index 7a16147..23abe60 100644 --- a/base2k/src/module.rs +++ b/base2k/src/module.rs @@ -3,7 +3,7 @@ use crate::GALOISGENERATOR; #[derive(Copy, Clone)] #[repr(u8)] -pub enum MODULETYPE { +pub enum BACKEND { FFT64, NTT120, } @@ -11,17 +11,17 @@ pub enum MODULETYPE { pub struct Module { pub ptr: *mut MODULE, pub n: usize, - pub backend: MODULETYPE, + pub backend: BACKEND, } impl Module { // Instantiates a new module. - pub fn new(n: usize, module_type: MODULETYPE) -> Self { + pub fn new(n: usize, module_type: BACKEND) -> Self { unsafe { let module_type_u32: u32; match module_type { - MODULETYPE::FFT64 => module_type_u32 = 0, - MODULETYPE::NTT120 => module_type_u32 = 1, + BACKEND::FFT64 => module_type_u32 = 0, + BACKEND::NTT120 => module_type_u32 = 1, } let m: *mut module_info_t = new_module_info(n as u64, module_type_u32); if m.is_null() { @@ -35,7 +35,7 @@ impl Module { } } - pub fn backend(&self) -> MODULETYPE { + pub fn backend(&self) -> BACKEND { self.backend } diff --git a/base2k/src/vec_znx.rs b/base2k/src/vec_znx.rs index aeb64f6..6f45bac 100644 --- a/base2k/src/vec_znx.rs +++ b/base2k/src/vec_znx.rs @@ -540,31 +540,6 @@ impl VecZnxOps for Module { /// # Panics /// /// The method will panic if the argument `a` is greater than `a.cols()`. - /// - /// # Example - /// ``` - /// use base2k::{Module, MODULETYPE, VecZnx, Encoding, Infos, VecZnxOps}; - /// use itertools::izip; - /// - /// let n: usize = 8; // polynomial degree - /// let module = Module::new(n, MODULETYPE::FFT64); - /// let mut a: VecZnx = module.new_vec_znx(2); - /// let mut b: VecZnx = module.new_vec_znx(2); - /// let mut c: VecZnx = module.new_vec_znx(2); - /// - /// (0..a.cols()).for_each(|i|{ - /// a.at_mut(i).iter_mut().enumerate().for_each(|(i, x)|{ - /// *x = i as i64 - /// }) - /// }); - /// - /// module.vec_znx_automorphism(-1, &mut b, &a, 1); // X^i -> X^(-i) - /// let col = c.at_mut(0); - /// (1..col.len()).for_each(|i|{ - /// col[n-i] = -(i as i64) - /// }); - /// izip!(b.raw().iter(), c.raw().iter()).for_each(|(a, b)| assert_eq!(a, b, "{} != {}", a, b)); - /// ``` fn vec_znx_automorphism(&self, k: i64, b: &mut VecZnx, a: &VecZnx, a_cols: usize) { debug_assert_eq!(a.n(), self.n()); debug_assert_eq!(b.n(), self.n()); @@ -594,30 +569,6 @@ impl VecZnxOps for Module { /// # Panics /// /// The method will panic if the argument `cols` is greater than `self.cols()`. - /// - /// # Example - /// ``` - /// use base2k::{Module, MODULETYPE, VecZnx, Encoding, Infos, VecZnxOps}; - /// use itertools::izip; - /// - /// let n: usize = 8; // polynomial degree - /// let module = Module::new(n, MODULETYPE::FFT64); - /// let mut a: VecZnx = VecZnx::new(n, 2); - /// let mut b: VecZnx = VecZnx::new(n, 2); - /// - /// (0..a.cols()).for_each(|i|{ - /// a.at_mut(i).iter_mut().enumerate().for_each(|(i, x)|{ - /// *x = i as i64 - /// }) - /// }); - /// - /// module.vec_znx_automorphism_inplace(-1, &mut a, 1); // X^i -> X^(-i) - /// let col = b.at_mut(0); - /// (1..col.len()).for_each(|i|{ - /// col[n-i] = -(i as i64) - /// }); - /// izip!(a.raw().iter(), b.raw().iter()).for_each(|(a, b)| assert_eq!(a, b, "{} != {}", a, b)); - /// ``` fn vec_znx_automorphism_inplace(&self, k: i64, a: &mut VecZnx, a_cols: usize) { debug_assert_eq!(a.n(), self.n()); debug_assert!(a.cols() >= a_cols); diff --git a/base2k/src/vec_znx_big.rs b/base2k/src/vec_znx_big.rs index 90942c7..8b63992 100644 --- a/base2k/src/vec_znx_big.rs +++ b/base2k/src/vec_znx_big.rs @@ -1,12 +1,12 @@ -use crate::ffi::vec_znx_big::{self, vec_znx_bigcoeff_t}; -use crate::{alloc_aligned, assert_alignement, Infos, Module, VecZnx, VecZnxDft, MODULETYPE}; +use crate::ffi::vec_znx_big::{self, vec_znx_big_t}; +use crate::{alloc_aligned, assert_alignement, Infos, Module, VecZnx, VecZnxDft, BACKEND}; pub struct VecZnxBig { pub data: Vec, pub ptr: *mut u8, pub n: usize, pub cols: usize, - pub backend: MODULETYPE, + pub backend: BACKEND, } impl VecZnxBig { @@ -62,9 +62,19 @@ impl VecZnxBig { self.cols } - pub fn backend(&self) -> MODULETYPE { + pub fn backend(&self) -> BACKEND { self.backend } + + /// Returns a non-mutable reference of `T` of the entire contiguous array of the [VecZnxDft]. + /// When using [`crate::FFT64`] as backend, `T` should be [f64]. + /// When using [`crate::NTT120`] as backend, `T` should be [i64]. + /// The length of the returned array is cols * n. + pub fn raw(&self, module: &Module) -> &[T] { + let ptr: *const T = self.ptr as *const T; + let len: usize = (self.cols() * module.n() * 8) / std::mem::size_of::(); + unsafe { &std::slice::from_raw_parts(ptr, len) } + } } pub trait VecZnxBigOps { @@ -162,12 +172,12 @@ impl VecZnxBigOps for Module { unsafe { vec_znx_big::vec_znx_big_sub_small_a( self.ptr, - b.ptr as *mut vec_znx_bigcoeff_t, + b.ptr as *mut vec_znx_big_t, b.cols() as u64, a.as_ptr(), a.cols() as u64, a.n() as u64, - b.ptr as *mut vec_znx_bigcoeff_t, + b.ptr as *mut vec_znx_big_t, b.cols() as u64, ) } @@ -177,12 +187,12 @@ impl VecZnxBigOps for Module { unsafe { vec_znx_big::vec_znx_big_sub_small_a( self.ptr, - c.ptr as *mut vec_znx_bigcoeff_t, + c.ptr as *mut vec_znx_big_t, c.cols() as u64, a.as_ptr(), a.cols() as u64, a.n() as u64, - b.ptr as *mut vec_znx_bigcoeff_t, + b.ptr as *mut vec_znx_big_t, b.cols() as u64, ) } @@ -192,9 +202,9 @@ impl VecZnxBigOps for Module { unsafe { vec_znx_big::vec_znx_big_add_small( self.ptr, - c.ptr as *mut vec_znx_bigcoeff_t, + c.ptr as *mut vec_znx_big_t, c.cols() as u64, - b.ptr as *mut vec_znx_bigcoeff_t, + b.ptr as *mut vec_znx_big_t, b.cols() as u64, a.as_ptr(), a.cols() as u64, @@ -207,9 +217,9 @@ impl VecZnxBigOps for Module { unsafe { vec_znx_big::vec_znx_big_add_small( self.ptr, - b.ptr as *mut vec_znx_bigcoeff_t, + b.ptr as *mut vec_znx_big_t, b.cols() as u64, - b.ptr as *mut vec_znx_bigcoeff_t, + b.ptr as *mut vec_znx_big_t, b.cols() as u64, a.as_ptr(), a.cols() as u64, @@ -246,7 +256,7 @@ impl VecZnxBigOps for Module { b.as_mut_ptr(), b.cols() as u64, b.n() as u64, - a.ptr as *mut vec_znx_bigcoeff_t, + a.ptr as *mut vec_znx_big_t, a.cols() as u64, tmp_bytes.as_mut_ptr(), ) @@ -284,7 +294,7 @@ impl VecZnxBigOps for Module { res.as_mut_ptr(), res.cols() as u64, res.n() as u64, - a.ptr as *mut vec_znx_bigcoeff_t, + a.ptr as *mut vec_znx_big_t, a_range_begin as u64, a_range_xend as u64, a_range_step as u64, @@ -298,9 +308,9 @@ impl VecZnxBigOps for Module { vec_znx_big::vec_znx_big_automorphism( self.ptr, gal_el, - b.ptr as *mut vec_znx_bigcoeff_t, + b.ptr as *mut vec_znx_big_t, b.cols() as u64, - a.ptr as *mut vec_znx_bigcoeff_t, + a.ptr as *mut vec_znx_big_t, a.cols() as u64, ); } @@ -311,9 +321,9 @@ impl VecZnxBigOps for Module { vec_znx_big::vec_znx_big_automorphism( self.ptr, gal_el, - a.ptr as *mut vec_znx_bigcoeff_t, + a.ptr as *mut vec_znx_big_t, a.cols() as u64, - a.ptr as *mut vec_znx_bigcoeff_t, + a.ptr as *mut vec_znx_big_t, a.cols() as u64, ); } diff --git a/base2k/src/vec_znx_dft.rs b/base2k/src/vec_znx_dft.rs index 275b401..4ad8525 100644 --- a/base2k/src/vec_znx_dft.rs +++ b/base2k/src/vec_znx_dft.rs @@ -1,15 +1,15 @@ -use crate::ffi::vec_znx_big::vec_znx_bigcoeff_t; +use crate::ffi::vec_znx_big::vec_znx_big_t; use crate::ffi::vec_znx_dft; use crate::ffi::vec_znx_dft::{bytes_of_vec_znx_dft, vec_znx_dft_t}; use crate::{alloc_aligned, VecZnx}; -use crate::{assert_alignement, Infos, Module, VecZnxBig, MODULETYPE}; +use crate::{assert_alignement, Infos, Module, VecZnxBig, BACKEND}; pub struct VecZnxDft { pub data: Vec, pub ptr: *mut u8, pub n: usize, pub cols: usize, - pub backend: MODULETYPE, + pub backend: BACKEND, } impl VecZnxDft { @@ -69,7 +69,7 @@ impl VecZnxDft { self.cols } - pub fn backend(&self) -> MODULETYPE { + pub fn backend(&self) -> BACKEND { self.backend } @@ -133,17 +133,17 @@ pub trait VecZnxDftOps { fn vec_znx_idft_tmp_bytes(&self) -> usize; /// b <- IDFT(a), uses a as scratch space. - fn vec_znx_idft_tmp_a(&self, b: &mut VecZnxBig, a: &mut VecZnxDft, a_limbs: usize); + fn vec_znx_idft_tmp_a(&self, b: &mut VecZnxBig, a: &mut VecZnxDft, a_cols: usize); fn vec_znx_idft( &self, b: &mut VecZnxBig, - a: &mut VecZnxDft, - a_limbs: usize, + a: &VecZnxDft, + a_cols: usize, tmp_bytes: &mut [u8], ); - fn vec_znx_dft(&self, b: &mut VecZnxDft, a: &VecZnx, a_limbs: usize); + fn vec_znx_dft(&self, b: &mut VecZnxDft, a: &VecZnx, a_cols: usize); } impl VecZnxDftOps for Module { @@ -177,20 +177,20 @@ impl VecZnxDftOps for Module { unsafe { bytes_of_vec_znx_dft(self.ptr, cols as u64) as usize } } - fn vec_znx_idft_tmp_a(&self, b: &mut VecZnxBig, a: &mut VecZnxDft, a_limbs: usize) { + fn vec_znx_idft_tmp_a(&self, b: &mut VecZnxBig, a: &mut VecZnxDft, a_cols: usize) { debug_assert!( - b.cols() >= a_limbs, - "invalid c_vector: b_vector.cols()={} < a_limbs={}", + b.cols() >= a_cols, + "invalid c_vector: b_vector.cols()={} < a_cols={}", b.cols(), - a_limbs + a_cols ); unsafe { vec_znx_dft::vec_znx_idft_tmp_a( self.ptr, - b.ptr as *mut vec_znx_bigcoeff_t, + b.ptr as *mut vec_znx_big_t, b.cols() as u64, a.ptr as *mut vec_znx_dft_t, - a_limbs as u64, + a_cols as u64, ) } } @@ -226,7 +226,7 @@ impl VecZnxDftOps for Module { fn vec_znx_idft( &self, b: &mut VecZnxBig, - a: &mut VecZnxDft, + a: &VecZnxDft, a_cols: usize, tmp_bytes: &mut [u8], ) { @@ -243,7 +243,7 @@ impl VecZnxDftOps for Module { a_cols ); debug_assert!( - tmp_bytes.len() <= ::vec_znx_idft_tmp_bytes(self), + tmp_bytes.len() >= ::vec_znx_idft_tmp_bytes(self), "invalid tmp_bytes: tmp_bytes.len()={} < self.vec_znx_idft_tmp_bytes()={}", tmp_bytes.len(), ::vec_znx_idft_tmp_bytes(self) @@ -255,9 +255,9 @@ impl VecZnxDftOps for Module { unsafe { vec_znx_dft::vec_znx_idft( self.ptr, - b.ptr as *mut vec_znx_bigcoeff_t, + b.ptr as *mut vec_znx_big_t, a.cols() as u64, - a.ptr as *mut vec_znx_dft_t, + a.ptr as *const vec_znx_dft_t, a_cols as u64, tmp_bytes.as_mut_ptr(), ) diff --git a/base2k/src/vmp.rs b/base2k/src/vmp.rs index 7195cfb..e1c84d3 100644 --- a/base2k/src/vmp.rs +++ b/base2k/src/vmp.rs @@ -1,6 +1,9 @@ +use crate::ffi::vec_znx_big::vec_znx_big_t; use crate::ffi::vec_znx_dft::vec_znx_dft_t; use crate::ffi::vmp::{self, vmp_pmat_t}; -use crate::{alloc_aligned, assert_alignement, Infos, Module, VecZnx, VecZnxDft, MODULETYPE}; +use crate::{ + alloc_aligned, assert_alignement, Infos, Module, VecZnx, VecZnxBig, VecZnxDft, BACKEND, +}; /// Vector Matrix Product Prepared Matrix: a vector of [VecZnx], /// stored as a 3D matrix in the DFT domain in a single contiguous array. @@ -23,7 +26,7 @@ pub struct VmpPMat { /// The ring degree of each [VecZnxDft]. n: usize, - backend: MODULETYPE, + backend: BACKEND, } impl Infos for VmpPMat { @@ -59,7 +62,7 @@ impl VmpPMat { self.ptr } - pub fn borrowed(&self) -> bool{ + pub fn borrowed(&self) -> bool { self.data.len() == 0 } @@ -167,7 +170,7 @@ pub trait VmpPMatOps { /// The size of buf can be obtained with [VmpPMatOps::vmp_prepare_tmp_bytes]. fn vmp_prepare_dblptr(&self, b: &mut VmpPMat, a: &[&[i64]], buf: &mut [u8]); - /// Prepares the ith-row of [VmpPMat] from a vector of [VecZnx]. + /// Prepares the ith-row of [VmpPMat] from a [VecZnx]. /// /// # Arguments /// @@ -179,6 +182,35 @@ pub trait VmpPMatOps { /// The size of buf can be obtained with [VmpPMatOps::vmp_prepare_tmp_bytes]. fn vmp_prepare_row(&self, b: &mut VmpPMat, a: &[i64], row_i: usize, tmp_bytes: &mut [u8]); + /// Extracts the ith-row of [VmpPMat] into a [VecZnxBig]. + /// + /// # Arguments + /// + /// * `b`: the [VecZnxBig] to on which to extract the row of the [VmpPMat]. + /// * `a`: [VmpPMat] on which the values are encoded. + /// * `row_i`: the index of the row to extract. + fn vmp_extract_row(&self, b: &mut VecZnxBig, a: &VmpPMat, row_i: usize); + + /// Prepares the ith-row of [VmpPMat] from a [VecZnxDft]. + /// + /// # Arguments + /// + /// * `b`: [VmpPMat] on which the values are encoded. + /// * `a`: the [VecZnxDft] to encode on the [VmpPMat]. + /// * `row_i`: the index of the row to prepare. + /// + /// The size of buf can be obtained with [VmpPMatOps::vmp_prepare_tmp_bytes]. + fn vmp_prepare_row_dft(&self, b: &mut VmpPMat, a: &VecZnxDft, row_i: usize); + + /// Extracts the ith-row of [VmpPMat] into a [VecZnxDft]. + /// + /// # Arguments + /// + /// * `b`: the [VecZnxDft] to on which to extract the row of the [VmpPMat]. + /// * `a`: [VmpPMat] on which the values are encoded. + /// * `row_i`: the index of the row to extract. + fn vmp_extract_row_dft(&self, b: &mut VecZnxDft, a: &VmpPMat, row_i: usize); + /// Returns the size of the stratch space necessary for [VmpPMatOps::vmp_apply_dft]. /// /// # Arguments @@ -375,6 +407,60 @@ impl VmpPMatOps for Module { } } + fn vmp_extract_row(&self, b: &mut VecZnxBig, a: &VmpPMat, row_i: usize) { + #[cfg(debug_assertions)] + { + assert_eq!(a.n(), b.n()); + assert_eq!(a.cols(), b.cols()); + } + unsafe { + vmp::vmp_extract_row( + self.ptr, + b.ptr as *mut vec_znx_big_t, + a.as_ptr() as *const vmp_pmat_t, + row_i as u64, + a.rows() as u64, + a.cols() as u64, + ); + } + } + + fn vmp_prepare_row_dft(&self, b: &mut VmpPMat, a: &VecZnxDft, row_i: usize) { + #[cfg(debug_assertions)] + { + assert_eq!(a.n(), b.n()); + assert_eq!(a.cols(), b.cols()); + } + unsafe { + vmp::vmp_prepare_row_dft( + self.ptr, + b.as_mut_ptr() as *mut vmp_pmat_t, + a.ptr as *const vec_znx_dft_t, + row_i as u64, + b.rows() as u64, + b.cols() as u64, + ); + } + } + + fn vmp_extract_row_dft(&self, b: &mut VecZnxDft, a: &VmpPMat, row_i: usize) { + #[cfg(debug_assertions)] + { + assert_eq!(a.n(), b.n()); + assert_eq!(a.cols(), b.cols()); + } + unsafe { + vmp::vmp_extract_row_dft( + self.ptr, + b.ptr as *mut vec_znx_dft_t, + a.as_ptr() as *const vmp_pmat_t, + row_i as u64, + a.rows() as u64, + a.cols() as u64, + ); + } + } + fn vmp_apply_dft_tmp_bytes( &self, res_cols: usize, @@ -489,3 +575,52 @@ impl VmpPMatOps for Module { } } } + +#[cfg(test)] +mod tests { + use crate::{ + alloc_aligned, Module, Sampling, VecZnx, VecZnxBig, VecZnxBigOps, VecZnxDft, VecZnxDftOps, + VecZnxOps, VmpPMat, VmpPMatOps, + }; + use sampling::source::Source; + + #[test] + fn vmp_prepare_row_dft() { + let module: Module = Module::new(32, crate::BACKEND::FFT64); + let vpmat_rows: usize = 4; + let vpmat_cols: usize = 5; + let log_base2k: usize = 8; + let mut a: VecZnx = module.new_vec_znx(vpmat_cols); + let mut a_dft: VecZnxDft = module.new_vec_znx_dft(vpmat_cols); + let mut a_big: VecZnxBig = module.new_vec_znx_big(vpmat_cols); + let mut b_big: VecZnxBig = module.new_vec_znx_big(vpmat_cols); + let mut b_dft: VecZnxDft = module.new_vec_znx_dft(vpmat_cols); + let mut vmpmat_0: VmpPMat = module.new_vmp_pmat(vpmat_rows, vpmat_cols); + let mut vmpmat_1: VmpPMat = module.new_vmp_pmat(vpmat_rows, vpmat_cols); + + let mut tmp_bytes: Vec = + alloc_aligned(module.vmp_prepare_tmp_bytes(vpmat_rows, vpmat_cols)); + + for row_i in 0..vpmat_rows { + let mut source: Source = Source::new([0u8; 32]); + module.fill_uniform(log_base2k, &mut a, vpmat_cols, &mut source); + module.vec_znx_dft(&mut a_dft, &a, vpmat_cols); + module.vmp_prepare_row(&mut vmpmat_0, &a.raw(), row_i, &mut tmp_bytes); + + // Checks that prepare(vmp_pmat, a) = prepare_dft(vmp_pmat, a_dft) + module.vmp_prepare_row_dft(&mut vmpmat_1, &a_dft, row_i); + assert_eq!(vmpmat_0.raw::(), vmpmat_1.raw::()); + + // Checks that a_dft = extract_dft(prepare(vmp_pmat, a), b_dft) + module.vmp_extract_row_dft(&mut b_dft, &vmpmat_0, row_i); + assert_eq!(a_dft.raw::(&module), b_dft.raw::(&module)); + + // Checks that a_big = extract(prepare_dft(vmp_pmat, a_dft), b_big) + module.vmp_extract_row(&mut b_big, &vmpmat_0, row_i); + module.vec_znx_idft(&mut a_big, &a_dft, vpmat_cols, &mut tmp_bytes); + assert_eq!(a_big.raw::(&module), b_big.raw::(&module)); + } + + module.free(); + } +} diff --git a/rlwe/benches/gadget_product.rs b/rlwe/benches/gadget_product.rs index e7dd8c2..b2aeaa0 100644 --- a/rlwe/benches/gadget_product.rs +++ b/rlwe/benches/gadget_product.rs @@ -1,5 +1,5 @@ use base2k::{ - Infos, MODULETYPE, Module, Sampling, SvpPPolOps, VecZnx, VecZnxDft, VecZnxDftOps, VecZnxOps, + Infos, BACKEND, Module, Sampling, SvpPPolOps, VecZnx, VecZnxDft, VecZnxDftOps, VecZnxOps, VmpPMat, alloc_aligned_u8, }; use criterion::{BenchmarkId, Criterion, criterion_group, criterion_main}; @@ -36,7 +36,7 @@ fn bench_gadget_product_inplace(c: &mut Criterion) { for log_n in 10..11 { let params_lit: ParametersLiteral = ParametersLiteral { - backend: MODULETYPE::FFT64, + backend: BACKEND::FFT64, log_n: log_n, log_q: 32, log_p: 0, diff --git a/rlwe/examples/encryption.rs b/rlwe/examples/encryption.rs index 4002a95..8ae1d57 100644 --- a/rlwe/examples/encryption.rs +++ b/rlwe/examples/encryption.rs @@ -10,7 +10,7 @@ use sampling::source::Source; fn main() { let params_lit: ParametersLiteral = ParametersLiteral { - backend: base2k::MODULETYPE::FFT64, + backend: base2k::BACKEND::FFT64, log_n: 10, log_q: 54, log_p: 0, diff --git a/rlwe/examples/rlk_experiments.rs b/rlwe/examples/rlk_experiments.rs index a35904c..cd2331b 100644 --- a/rlwe/examples/rlk_experiments.rs +++ b/rlwe/examples/rlk_experiments.rs @@ -12,7 +12,7 @@ use sampling::source::{Source, new_seed}; fn main() { let n: usize = 32; - let module: Module = Module::new(n, base2k::MODULETYPE::FFT64); + let module: Module = Module::new(n, base2k::BACKEND::FFT64); let log_base2k: usize = 16; let log_k: usize = 32; let cols: usize = 4; diff --git a/rlwe/src/gadget_product.rs b/rlwe/src/gadget_product.rs index 0139f88..3011ac4 100644 --- a/rlwe/src/gadget_product.rs +++ b/rlwe/src/gadget_product.rs @@ -97,7 +97,7 @@ mod test { plaintext::Plaintext, }; use base2k::{ - Infos, MODULETYPE, Sampling, SvpPPolOps, VecZnx, VecZnxBig, VecZnxBigOps, VecZnxDft, + Infos, BACKEND, Sampling, SvpPPolOps, VecZnx, VecZnxBig, VecZnxBigOps, VecZnxDft, VecZnxDftOps, VecZnxOps, VmpPMat, alloc_aligned_u8, }; use sampling::source::{Source, new_seed}; @@ -110,7 +110,7 @@ mod test { // Basic parameters with enough limbs to test edge cases let params_lit: ParametersLiteral = ParametersLiteral { - backend: MODULETYPE::FFT64, + backend: BACKEND::FFT64, log_n: 12, log_q: q_cols * log_base2k, log_p: p_cols * log_base2k, diff --git a/rlwe/src/parameters.rs b/rlwe/src/parameters.rs index a0860d0..1faa5f4 100644 --- a/rlwe/src/parameters.rs +++ b/rlwe/src/parameters.rs @@ -1,7 +1,7 @@ -use base2k::module::{MODULETYPE, Module}; +use base2k::module::{BACKEND, Module}; pub struct ParametersLiteral { - pub backend: MODULETYPE, + pub backend: BACKEND, pub log_n: usize, pub log_q: usize, pub log_p: usize,