From d3a8d206473bf89000062a5290816ef93eec5fb9 Mon Sep 17 00:00:00 2001 From: Jean-Philippe Bossuat Date: Thu, 30 Jan 2025 17:34:57 +0100 Subject: [PATCH] rewrote all bindings, removed dependency on binding generation --- Cargo.lock | 101 ---------- base2k/Cargo.toml | 3 - base2k/benches/fft.rs | 6 +- base2k/build.rs | 16 +- base2k/examples/fft.rs | 6 +- base2k/examples/rlwe_encrypt.rs | 4 +- base2k/examples/vector_matrix_product.rs | 26 ++- base2k/src/ffi/cnv.rs | 7 + base2k/src/ffi/mod.rs | 8 + base2k/src/ffi/module.rs | 20 ++ base2k/src/ffi/reim.rs | 242 +++++++++++++++++++++++ base2k/src/ffi/svp.rs | 35 ++++ base2k/src/ffi/vec_znx.rs | 101 ++++++++++ base2k/src/ffi/vec_znx_big.rs | 158 +++++++++++++++ base2k/src/ffi/vec_znx_dft.rs | 77 ++++++++ base2k/src/ffi/vmp.rs | 96 +++++++++ base2k/src/ffi/znx.rs | 89 +++++++++ base2k/src/lib.rs | 16 +- base2k/src/module.rs | 30 +-- base2k/src/scalar_vector_product.rs | 13 +- base2k/src/vec_znx.rs | 24 +-- base2k/src/vec_znx_arithmetic.rs | 2 +- base2k/src/vec_znx_big_arithmetic.rs | 28 ++- base2k/src/vec_znx_dft.rs | 24 ++- base2k/src/vector_matrix_product.rs | 97 +++++++-- 25 files changed, 1040 insertions(+), 189 deletions(-) create mode 100644 base2k/src/ffi/cnv.rs create mode 100644 base2k/src/ffi/mod.rs create mode 100644 base2k/src/ffi/module.rs create mode 100644 base2k/src/ffi/reim.rs create mode 100644 base2k/src/ffi/svp.rs create mode 100644 base2k/src/ffi/vec_znx.rs create mode 100644 base2k/src/ffi/vec_znx_big.rs create mode 100644 base2k/src/ffi/vec_znx_dft.rs create mode 100644 base2k/src/ffi/vmp.rs create mode 100644 base2k/src/ffi/znx.rs diff --git a/Cargo.lock b/Cargo.lock index 6b5fc47..c2c3c7c 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -53,7 +53,6 @@ checksum = "ace50bade8e6234aa140d9a2f552bbee1db4d353f69b8217bc503490fc1a9f26" name = "base2k" version = "0.1.0" dependencies = [ - "bindgen", "criterion", "itertools 0.14.0", "rand", @@ -63,32 +62,6 @@ dependencies = [ "utils", ] -[[package]] -name = "bindgen" -version = "0.71.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5f58bf3d7db68cfbac37cfc485a8d711e87e064c3d0fe0435b92f7a407f9d6b3" -dependencies = [ - "bitflags", - "cexpr", - "clang-sys", - "itertools 0.10.5", - "log", - "prettyplease", - "proc-macro2", - "quote", - "regex", - "rustc-hash", - "shlex", - "syn", -] - -[[package]] -name = "bitflags" -version = "2.8.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8f68f53c83ab957f72c32642f3868eec03eb974d1fb82e453128456482613d36" - [[package]] name = "bumpalo" version = "3.16.0" @@ -107,15 +80,6 @@ version = "0.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "37b2a672a2cb129a2e41c10b1224bb368f9f37a2b16b612598138befd7b37eb5" -[[package]] -name = "cexpr" -version = "0.6.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6fac387a98bb7c37292057cffc56d62ecb629900026402633ae9160df93a8766" -dependencies = [ - "nom", -] - [[package]] name = "cfg-if" version = "1.0.0" @@ -149,17 +113,6 @@ dependencies = [ "half", ] -[[package]] -name = "clang-sys" -version = "1.8.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0b023947811758c97c59bf9d1c188fd619ad4718dcaa767947df1cadb14f39f4" -dependencies = [ - "glob", - "libc", - "libloading", -] - [[package]] name = "clap" version = "4.5.23" @@ -275,12 +228,6 @@ dependencies = [ "wasi", ] -[[package]] -name = "glob" -version = "0.3.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a8d1add55171497b4705a648c6b583acafb01d58050a51727785f0b2c8e0a2b2" - [[package]] name = "half" version = "2.4.1" @@ -354,16 +301,6 @@ version = "0.2.167" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "09d6582e104315a817dff97f75133544b2e094ee22447d2acf4a74e189ba06fc" -[[package]] -name = "libloading" -version = "0.8.6" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fc2f4eb4bc735547cfed7c0a4922cbd04a4655978c09b54f1f7b228750664c34" -dependencies = [ - "cfg-if", - "windows-targets", -] - [[package]] name = "libm" version = "0.2.11" @@ -392,12 +329,6 @@ version = "2.7.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "78ca9ab1a0babb1e7d5695e3530886289c18cf2f87ec19a575a0abdce112e3a3" -[[package]] -name = "minimal-lexical" -version = "0.2.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "68354c5c6bd36d73ff3feceb05efa59b6acb7626617f4962be322a825e61f79a" - [[package]] name = "ndarray" version = "0.16.1" @@ -413,16 +344,6 @@ dependencies = [ "rawpointer", ] -[[package]] -name = "nom" -version = "7.1.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d273983c5a657a70a3e8f2a01329822f3b8c8172b73826411a55751e404a0a4a" -dependencies = [ - "memchr", - "minimal-lexical", -] - [[package]] name = "num" version = "0.4.3" @@ -581,16 +502,6 @@ dependencies = [ "zerocopy", ] -[[package]] -name = "prettyplease" -version = "0.2.29" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6924ced06e1f7dfe3fa48d57b9f74f55d8915f5036121bef647ef4b204895fac" -dependencies = [ - "proc-macro2", - "syn", -] - [[package]] name = "primality-test" version = "0.3.0" @@ -747,12 +658,6 @@ dependencies = [ "utils", ] -[[package]] -name = "rustc-hash" -version = "2.1.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c7fb8039b3032c191086b10f11f319a6e99e1e82889c5cc6046f515c9db1d497" - [[package]] name = "ryu" version = "1.0.18" @@ -809,12 +714,6 @@ dependencies = [ "serde", ] -[[package]] -name = "shlex" -version = "1.3.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0fda2ff0d084019ba4d7c6f371c95d8fd75ce3524c3cb8fb653a3023f6323e64" - [[package]] name = "smallvec" version = "1.13.2" diff --git a/base2k/Cargo.toml b/base2k/Cargo.toml index 44ccddd..5e829f2 100644 --- a/base2k/Cargo.toml +++ b/base2k/Cargo.toml @@ -12,9 +12,6 @@ rand_core = {workspace = true} sampling = { path = "../sampling" } utils = { path = "../utils" } -[build-dependencies] -bindgen ="0.71.1" - [[bench]] name = "fft" harness = false \ No newline at end of file diff --git a/base2k/benches/fft.rs b/base2k/benches/fft.rs index 11027be..17e6e14 100644 --- a/base2k/benches/fft.rs +++ b/base2k/benches/fft.rs @@ -1,8 +1,4 @@ -use base2k::bindings::{ - new_reim_fft_precomp, new_reim_ifft_precomp, reim_fft, reim_fft_precomp, - reim_fft_precomp_get_buffer, reim_from_znx64_simple, reim_ifft, reim_ifft_precomp, - reim_ifft_precomp_get_buffer, -}; +use base2k::ffi::reim::*; use criterion::{criterion_group, criterion_main, BenchmarkId, Criterion}; use std::ffi::c_void; diff --git a/base2k/build.rs b/base2k/build.rs index 0b2ffd8..c11becf 100644 --- a/base2k/build.rs +++ b/base2k/build.rs @@ -1,11 +1,16 @@ -use bindgen; -use std::env; -use std::fs; +//use bindgen; +//use std::env; +//use std::fs; use std::path::absolute; -use std::path::PathBuf; -use std::time::SystemTime; +//use std::path::PathBuf; +//use std::time::SystemTime; fn main() { + /* + + [build-dependencies] + bindgen ="0.71.1" + // Path to the C header file let header_paths: [&str; 2] = [ "spqlios-arithmetic/spqlios/coeffs/coeffs_arithmetic.h", @@ -43,6 +48,7 @@ fn main() { .write_to_file(&bindings_file) .expect("Couldn't write bindings!"); } + */ println!( "cargo:rustc-link-search=native={}", diff --git a/base2k/examples/fft.rs b/base2k/examples/fft.rs index 74a742a..181d066 100644 --- a/base2k/examples/fft.rs +++ b/base2k/examples/fft.rs @@ -1,8 +1,4 @@ -use base2k::bindings::{ - new_reim_fft_precomp, new_reim_ifft_precomp, reim_fft, reim_fft_precomp_get_buffer, - reim_fftvec_mul_simple, reim_from_znx64_simple, reim_ifft, reim_ifft_precomp_get_buffer, - reim_to_znx64_simple, -}; +use base2k::ffi::reim::*; use std::ffi::c_void; use std::time::Instant; diff --git a/base2k/examples/rlwe_encrypt.rs b/base2k/examples/rlwe_encrypt.rs index 569b831..1a76689 100644 --- a/base2k/examples/rlwe_encrypt.rs +++ b/base2k/examples/rlwe_encrypt.rs @@ -48,7 +48,7 @@ fn main() { .for_each(|x| *x = source.next_u64n(16, 15) as i64); // m - m.set_i64(&want, 4); + m.from_i64(&want, 4); m.normalize(&mut carry); // buf_big <- m - buf_big @@ -73,7 +73,7 @@ fn main() { // have = m * 2^{log_scale} + e let mut have: Vec = vec![i64::default(); n]; - res.get_i64(&mut have); + res.to_i64(&mut have); let scale: f64 = (1 << log_scale) as f64; izip!(want.iter(), have.iter()) diff --git a/base2k/examples/vector_matrix_product.rs b/base2k/examples/vector_matrix_product.rs index 7c4bccc..c93a7f07e 100644 --- a/base2k/examples/vector_matrix_product.rs +++ b/base2k/examples/vector_matrix_product.rs @@ -1,7 +1,7 @@ use base2k::{Matrix3D, Module, VecZnx, VecZnxBig, VecZnxDft, VmpPMat, FFT64}; fn main() { - let log_n = 4; + let log_n = 5; let n = 1 << log_n; let module: Module = Module::new::(n); @@ -14,7 +14,7 @@ fn main() { // Maximum size of the byte scratch needed let tmp_bytes: usize = module.vmp_prepare_contiguous_tmp_bytes(rows, cols) - | module.vmp_apply_dft_to_dft_tmp_bytes(limbs, limbs, rows, cols); + | module.vmp_apply_dft_tmp_bytes(limbs, limbs, rows, cols); let mut buf: Vec = vec![0; tmp_bytes]; @@ -22,7 +22,7 @@ fn main() { a_values[1] = (1 << log_base2k) + 1; let mut a: VecZnx = module.new_vec_znx(log_base2k, log_q); - a.set_i64(&a_values, 32); + a.from_i64(&a_values, 32); a.normalize(&mut buf); (0..a.limbs()).for_each(|i| println!("{}: {:?}", i, a.at(i))); @@ -30,9 +30,14 @@ fn main() { let mut b_mat: Matrix3D = Matrix3D::new(rows, cols, n); (0..rows).for_each(|i| { - b_mat.at_mut(i, i)[0] = ((1 << 15) + 1) << 15; + (0..cols).for_each(|j| { + b_mat.at_mut(i, j)[0] = (i * cols + j) as i64; + b_mat.at_mut(i, j)[0] = (i * cols + j) as i64; + }) }); + //b_mat.data.iter_mut().enumerate().for_each(|(i, xi)| *xi = i as i64); + println!(); (0..rows).for_each(|i| { (0..cols).for_each(|j| println!("{} {}: {:?}", i, j, b_mat.at(i, j))); @@ -42,6 +47,13 @@ fn main() { let mut vmp_pmat: VmpPMat = module.new_vmp_pmat(rows, cols); module.vmp_prepare_contiguous(&mut vmp_pmat, &b_mat.data, &mut buf); + (0..cols).for_each(|i| { + (0..rows).for_each(|j| println!("{} {}: {:?}", i, j, vmp_pmat.at(i, j))); + println!(); + }); + + println!("{:?}", vmp_pmat.as_f64()); + let mut c_dft: VecZnxDft = module.new_vec_znx_dft(limbs); module.vmp_apply_dft(&mut c_dft, &a, &vmp_pmat, &mut buf); @@ -52,9 +64,13 @@ fn main() { module.vec_znx_big_normalize(&mut res, &c_big, &mut buf); let mut values_res: Vec = vec![i64::default(); n]; - res.get_i64(&mut values_res); + res.to_i64(&mut values_res); (0..res.limbs()).for_each(|i| println!("{}: {:?}", i, res.at(i))); + module.delete(); + c_dft.delete(); + vmp_pmat.delete(); + println!("{:?}", values_res) } diff --git a/base2k/src/ffi/cnv.rs b/base2k/src/ffi/cnv.rs new file mode 100644 index 0000000..be8aae3 --- /dev/null +++ b/base2k/src/ffi/cnv.rs @@ -0,0 +1,7 @@ +pub type CNV_PVEC_L = cnv_pvec_l_t; +#[repr(C)] +#[derive(Debug, Copy, Clone)] +pub struct cnv_pvec_r_t { + _unused: [u8; 0], +} +pub type CNV_PVEC_R = cnv_pvec_r_t; diff --git a/base2k/src/ffi/mod.rs b/base2k/src/ffi/mod.rs new file mode 100644 index 0000000..57e9291 --- /dev/null +++ b/base2k/src/ffi/mod.rs @@ -0,0 +1,8 @@ +pub mod module; +pub mod reim; +pub mod svp; +pub mod vec_znx; +pub mod vec_znx_big; +pub mod vec_znx_dft; +pub mod vmp; +pub mod znx; diff --git a/base2k/src/ffi/module.rs b/base2k/src/ffi/module.rs new file mode 100644 index 0000000..755d613 --- /dev/null +++ b/base2k/src/ffi/module.rs @@ -0,0 +1,20 @@ +pub struct module_info_t { + _unused: [u8; 0], +} + +pub type module_type_t = ::std::os::raw::c_uint; +pub const module_type_t_FFT64: module_type_t = 0; +pub const module_type_t_NTT120: module_type_t = 1; +pub use self::module_type_t as MODULE_TYPE; + +pub type MODULE = module_info_t; + +unsafe extern "C" { + pub unsafe fn new_module_info(N: u64, mode: MODULE_TYPE) -> *mut MODULE; +} +unsafe extern "C" { + pub unsafe fn delete_module_info(module_info: *mut MODULE); +} +unsafe extern "C" { + pub unsafe fn module_get_n(module: *const MODULE) -> u64; +} diff --git a/base2k/src/ffi/reim.rs b/base2k/src/ffi/reim.rs new file mode 100644 index 0000000..7993ee0 --- /dev/null +++ b/base2k/src/ffi/reim.rs @@ -0,0 +1,242 @@ +#[repr(C)] +#[derive(Debug, Copy, Clone)] +pub struct reim_fft_precomp { + _unused: [u8; 0], +} +pub type REIM_FFT_PRECOMP = reim_fft_precomp; +#[repr(C)] +#[derive(Debug, Copy, Clone)] +pub struct reim_ifft_precomp { + _unused: [u8; 0], +} +pub type REIM_IFFT_PRECOMP = reim_ifft_precomp; +#[repr(C)] +#[derive(Debug, Copy, Clone)] +pub struct reim_mul_precomp { + _unused: [u8; 0], +} +pub type REIM_FFTVEC_MUL_PRECOMP = reim_mul_precomp; +#[repr(C)] +#[derive(Debug, Copy, Clone)] +pub struct reim_addmul_precomp { + _unused: [u8; 0], +} +pub type REIM_FFTVEC_ADDMUL_PRECOMP = reim_addmul_precomp; +#[repr(C)] +#[derive(Debug, Copy, Clone)] +pub struct reim_from_znx32_precomp { + _unused: [u8; 0], +} +pub type REIM_FROM_ZNX32_PRECOMP = reim_from_znx32_precomp; +#[repr(C)] +#[derive(Debug, Copy, Clone)] +pub struct reim_from_znx64_precomp { + _unused: [u8; 0], +} +pub type REIM_FROM_ZNX64_PRECOMP = reim_from_znx64_precomp; +#[repr(C)] +#[derive(Debug, Copy, Clone)] +pub struct reim_from_tnx32_precomp { + _unused: [u8; 0], +} +pub type REIM_FROM_TNX32_PRECOMP = reim_from_tnx32_precomp; +#[repr(C)] +#[derive(Debug, Copy, Clone)] +pub struct reim_to_tnx32_precomp { + _unused: [u8; 0], +} +pub type REIM_TO_TNX32_PRECOMP = reim_to_tnx32_precomp; +#[repr(C)] +#[derive(Debug, Copy, Clone)] +pub struct reim_to_tnx_precomp { + _unused: [u8; 0], +} +pub type REIM_TO_TNX_PRECOMP = reim_to_tnx_precomp; +#[repr(C)] +#[derive(Debug, Copy, Clone)] +pub struct reim_to_znx64_precomp { + _unused: [u8; 0], +} +pub type REIM_TO_ZNX64_PRECOMP = reim_to_znx64_precomp; +unsafe extern "C" { + pub fn new_reim_fft_precomp(m: u32, num_buffers: u32) -> *mut REIM_FFT_PRECOMP; +} +unsafe extern "C" { + pub fn reim_fft_precomp_get_buffer( + tables: *const REIM_FFT_PRECOMP, + buffer_index: u32, + ) -> *mut f64; +} +unsafe extern "C" { + pub fn new_reim_fft_buffer(m: u32) -> *mut f64; +} +unsafe extern "C" { + pub fn delete_reim_fft_buffer(buffer: *mut f64); +} +unsafe extern "C" { + pub fn reim_fft(tables: *const REIM_FFT_PRECOMP, data: *mut f64); +} +unsafe extern "C" { + pub fn new_reim_ifft_precomp(m: u32, num_buffers: u32) -> *mut REIM_IFFT_PRECOMP; +} +unsafe extern "C" { + pub fn reim_ifft_precomp_get_buffer( + tables: *const REIM_IFFT_PRECOMP, + buffer_index: u32, + ) -> *mut f64; +} +unsafe extern "C" { + pub fn reim_ifft(tables: *const REIM_IFFT_PRECOMP, data: *mut f64); +} +unsafe extern "C" { + pub fn new_reim_fftvec_mul_precomp(m: u32) -> *mut REIM_FFTVEC_MUL_PRECOMP; +} +unsafe extern "C" { + pub fn reim_fftvec_mul( + tables: *const REIM_FFTVEC_MUL_PRECOMP, + r: *mut f64, + a: *const f64, + b: *const f64, + ); +} +unsafe extern "C" { + pub fn new_reim_fftvec_addmul_precomp(m: u32) -> *mut REIM_FFTVEC_ADDMUL_PRECOMP; +} +unsafe extern "C" { + pub fn reim_fftvec_addmul( + tables: *const REIM_FFTVEC_ADDMUL_PRECOMP, + r: *mut f64, + a: *const f64, + b: *const f64, + ); +} +unsafe extern "C" { + pub fn new_reim_from_znx32_precomp(m: u32, log2bound: u32) -> *mut REIM_FROM_ZNX32_PRECOMP; +} +unsafe extern "C" { + pub fn reim_from_znx32( + tables: *const REIM_FROM_ZNX32_PRECOMP, + r: *mut ::std::os::raw::c_void, + a: *const i32, + ); +} +unsafe extern "C" { + pub fn reim_from_znx64( + tables: *const REIM_FROM_ZNX64_PRECOMP, + r: *mut ::std::os::raw::c_void, + a: *const i64, + ); +} +unsafe extern "C" { + pub fn new_reim_from_znx64_precomp(m: u32, maxbnd: u32) -> *mut REIM_FROM_ZNX64_PRECOMP; +} +unsafe extern "C" { + pub fn reim_from_znx64_simple( + m: u32, + log2bound: u32, + r: *mut ::std::os::raw::c_void, + a: *const i64, + ); +} +unsafe extern "C" { + pub fn new_reim_from_tnx32_precomp(m: u32) -> *mut REIM_FROM_TNX32_PRECOMP; +} +unsafe extern "C" { + pub fn reim_from_tnx32( + tables: *const REIM_FROM_TNX32_PRECOMP, + r: *mut ::std::os::raw::c_void, + a: *const i32, + ); +} +unsafe extern "C" { + pub fn new_reim_to_tnx32_precomp( + m: u32, + divisor: f64, + log2overhead: u32, + ) -> *mut REIM_TO_TNX32_PRECOMP; +} +unsafe extern "C" { + pub fn reim_to_tnx32( + tables: *const REIM_TO_TNX32_PRECOMP, + r: *mut i32, + a: *const ::std::os::raw::c_void, + ); +} +unsafe extern "C" { + pub fn new_reim_to_tnx_precomp( + m: u32, + divisor: f64, + log2overhead: u32, + ) -> *mut REIM_TO_TNX_PRECOMP; +} +unsafe extern "C" { + pub fn reim_to_tnx(tables: *const REIM_TO_TNX_PRECOMP, r: *mut f64, a: *const f64); +} +unsafe extern "C" { + pub fn reim_to_tnx_simple(m: u32, divisor: f64, log2overhead: u32, r: *mut f64, a: *const f64); +} +unsafe extern "C" { + pub fn new_reim_to_znx64_precomp( + m: u32, + divisor: f64, + log2bound: u32, + ) -> *mut REIM_TO_ZNX64_PRECOMP; +} +unsafe extern "C" { + pub fn reim_to_znx64( + precomp: *const REIM_TO_ZNX64_PRECOMP, + r: *mut i64, + a: *const ::std::os::raw::c_void, + ); +} +unsafe extern "C" { + pub fn reim_to_znx64_simple( + m: u32, + divisor: f64, + log2bound: u32, + r: *mut i64, + a: *const ::std::os::raw::c_void, + ); +} +unsafe extern "C" { + pub fn reim_fft_simple(m: u32, data: *mut ::std::os::raw::c_void); +} +unsafe extern "C" { + pub fn reim_ifft_simple(m: u32, data: *mut ::std::os::raw::c_void); +} +unsafe extern "C" { + pub fn reim_fftvec_mul_simple( + m: u32, + r: *mut ::std::os::raw::c_void, + a: *const ::std::os::raw::c_void, + b: *const ::std::os::raw::c_void, + ); +} +unsafe extern "C" { + pub fn reim_fftvec_addmul_simple( + m: u32, + r: *mut ::std::os::raw::c_void, + a: *const ::std::os::raw::c_void, + b: *const ::std::os::raw::c_void, + ); +} +unsafe extern "C" { + pub fn reim_from_znx32_simple( + m: u32, + log2bound: u32, + r: *mut ::std::os::raw::c_void, + x: *const i32, + ); +} +unsafe extern "C" { + pub fn reim_from_tnx32_simple(m: u32, r: *mut ::std::os::raw::c_void, x: *const i32); +} +unsafe extern "C" { + pub fn reim_to_tnx32_simple( + m: u32, + divisor: f64, + log2overhead: u32, + r: *mut i32, + x: *const ::std::os::raw::c_void, + ); +} diff --git a/base2k/src/ffi/svp.rs b/base2k/src/ffi/svp.rs new file mode 100644 index 0000000..71c871d --- /dev/null +++ b/base2k/src/ffi/svp.rs @@ -0,0 +1,35 @@ +use crate::ffi::module::MODULE; +use crate::ffi::vec_znx_dft::VEC_ZNX_DFT; + +#[repr(C)] +#[derive(Debug, Copy, Clone)] +pub struct svp_ppol_t { + _unused: [u8; 0], +} +pub type SVP_PPOL = svp_ppol_t; + +unsafe extern "C" { + pub unsafe fn bytes_of_svp_ppol(module: *const MODULE) -> u64; +} +unsafe extern "C" { + pub unsafe fn new_svp_ppol(module: *const MODULE) -> *mut SVP_PPOL; +} +unsafe extern "C" { + pub unsafe fn delete_svp_ppol(res: *mut SVP_PPOL); +} + +unsafe extern "C" { + pub unsafe fn svp_prepare(module: *const MODULE, ppol: *mut SVP_PPOL, pol: *const i64); +} + +unsafe extern "C" { + pub unsafe fn svp_apply_dft( + module: *const MODULE, + res: *const VEC_ZNX_DFT, + res_size: u64, + ppol: *const SVP_PPOL, + a: *const i64, + a_size: u64, + a_sl: u64, + ); +} diff --git a/base2k/src/ffi/vec_znx.rs b/base2k/src/ffi/vec_znx.rs new file mode 100644 index 0000000..897ef04 --- /dev/null +++ b/base2k/src/ffi/vec_znx.rs @@ -0,0 +1,101 @@ +use crate::ffi::module::MODULE; + +unsafe extern "C" { + pub unsafe fn vec_znx_add( + module: *const MODULE, + res: *mut i64, + res_size: u64, + res_sl: u64, + a: *const i64, + a_size: u64, + a_sl: u64, + b: *const i64, + b_size: u64, + b_sl: u64, + ); +} + +unsafe extern "C" { + pub unsafe fn vec_znx_automorphism( + module: *const MODULE, + p: i64, + res: *mut i64, + res_size: u64, + res_sl: u64, + a: *const i64, + a_size: u64, + a_sl: u64, + ); +} + +unsafe extern "C" { + pub unsafe fn vec_znx_negate( + module: *const MODULE, + res: *mut i64, + res_size: u64, + res_sl: u64, + a: *const i64, + a_size: u64, + a_sl: u64, + ); +} + +unsafe extern "C" { + pub unsafe fn vec_znx_rotate( + module: *const MODULE, + p: i64, + res: *mut i64, + res_size: u64, + res_sl: u64, + a: *const i64, + a_size: u64, + a_sl: u64, + ); +} + +unsafe extern "C" { + pub unsafe fn vec_znx_sub( + module: *const MODULE, + res: *mut i64, + res_size: u64, + res_sl: u64, + a: *const i64, + a_size: u64, + a_sl: u64, + b: *const i64, + b_size: u64, + b_sl: u64, + ); +} + +unsafe extern "C" { + pub fn vec_znx_zero(module: *const MODULE, res: *mut i64, res_size: u64, res_sl: u64); +} +unsafe extern "C" { + pub fn vec_znx_copy( + module: *const MODULE, + res: *mut i64, + res_size: u64, + res_sl: u64, + a: *const i64, + a_size: u64, + a_sl: u64, + ); +} + +unsafe extern "C" { + pub fn vec_znx_normalize_base2k( + module: *const MODULE, + log2_base2k: u64, + res: *mut i64, + res_size: u64, + res_sl: u64, + a: *const i64, + a_size: u64, + a_sl: u64, + tmp_space: *mut u8, + ); +} +unsafe extern "C" { + pub fn vec_znx_normalize_base2k_tmp_bytes(module: *const MODULE) -> u64; +} diff --git a/base2k/src/ffi/vec_znx_big.rs b/base2k/src/ffi/vec_znx_big.rs new file mode 100644 index 0000000..f2da750 --- /dev/null +++ b/base2k/src/ffi/vec_znx_big.rs @@ -0,0 +1,158 @@ +use crate::ffi::module::MODULE; + +#[repr(C)] +#[derive(Debug, Copy, Clone)] +pub struct vec_znx_bigcoeff_t { + _unused: [u8; 0], +} +pub type VEC_ZNX_BIG = vec_znx_bigcoeff_t; + +unsafe extern "C" { + pub fn bytes_of_vec_znx_big(module: *const MODULE, size: u64) -> u64; +} +unsafe extern "C" { + pub fn new_vec_znx_big(module: *const MODULE, size: u64) -> *mut VEC_ZNX_BIG; +} +unsafe extern "C" { + pub fn delete_vec_znx_big(res: *mut VEC_ZNX_BIG); +} + +unsafe extern "C" { + pub fn vec_znx_big_add( + module: *const MODULE, + res: *mut VEC_ZNX_BIG, + res_size: u64, + a: *const VEC_ZNX_BIG, + a_size: u64, + b: *const VEC_ZNX_BIG, + b_size: u64, + ); +} +unsafe extern "C" { + pub fn vec_znx_big_add_small( + module: *const MODULE, + res: *mut VEC_ZNX_BIG, + res_size: u64, + a: *const VEC_ZNX_BIG, + a_size: u64, + b: *const i64, + b_size: u64, + b_sl: u64, + ); +} +unsafe extern "C" { + pub fn vec_znx_big_add_small2( + module: *const MODULE, + res: *mut VEC_ZNX_BIG, + res_size: u64, + a: *const i64, + a_size: u64, + a_sl: u64, + b: *const i64, + b_size: u64, + b_sl: u64, + ); +} +unsafe extern "C" { + pub fn vec_znx_big_sub( + module: *const MODULE, + res: *mut VEC_ZNX_BIG, + res_size: u64, + a: *const VEC_ZNX_BIG, + a_size: u64, + b: *const VEC_ZNX_BIG, + b_size: u64, + ); +} +unsafe extern "C" { + pub fn vec_znx_big_sub_small_b( + module: *const MODULE, + res: *mut VEC_ZNX_BIG, + res_size: u64, + a: *const VEC_ZNX_BIG, + a_size: u64, + b: *const i64, + b_size: u64, + b_sl: u64, + ); +} +unsafe extern "C" { + pub fn vec_znx_big_sub_small_a( + module: *const MODULE, + res: *mut VEC_ZNX_BIG, + res_size: u64, + a: *const i64, + a_size: u64, + a_sl: u64, + b: *const VEC_ZNX_BIG, + b_size: u64, + ); +} +unsafe extern "C" { + pub fn vec_znx_big_sub_small2( + module: *const MODULE, + res: *mut VEC_ZNX_BIG, + res_size: u64, + a: *const i64, + a_size: u64, + a_sl: u64, + b: *const i64, + b_size: u64, + b_sl: u64, + ); +} +unsafe extern "C" { + pub fn vec_znx_big_normalize_base2k( + module: *const MODULE, + log2_base2k: u64, + res: *mut i64, + res_size: u64, + res_sl: u64, + a: *const VEC_ZNX_BIG, + a_size: u64, + tmp_space: *mut u8, + ); +} +unsafe extern "C" { + pub fn vec_znx_big_normalize_base2k_tmp_bytes(module: *const MODULE) -> u64; +} + +unsafe extern "C" { + pub fn vec_znx_big_automorphism( + module: *const MODULE, + p: i64, + res: *mut VEC_ZNX_BIG, + res_size: u64, + a: *const VEC_ZNX_BIG, + a_size: u64, + ); +} + +unsafe extern "C" { + pub fn vec_znx_big_rotate( + module: *const MODULE, + p: i64, + res: *mut VEC_ZNX_BIG, + res_size: u64, + a: *const VEC_ZNX_BIG, + a_size: u64, + ); +} + +unsafe extern "C" { + pub fn vec_znx_big_range_normalize_base2k( + module: *const MODULE, + log2_base2k: u64, + res: *mut i64, + res_size: u64, + res_sl: u64, + a: *const VEC_ZNX_BIG, + a_range_begin: u64, + a_range_xend: u64, + a_range_step: u64, + tmp_space: *mut u8, + ); +} +unsafe extern "C" { + pub fn vec_znx_big_range_normalize_base2k_tmp_bytes(module: *const MODULE) -> u64; +} diff --git a/base2k/src/ffi/vec_znx_dft.rs b/base2k/src/ffi/vec_znx_dft.rs new file mode 100644 index 0000000..6f43683 --- /dev/null +++ b/base2k/src/ffi/vec_znx_dft.rs @@ -0,0 +1,77 @@ +use crate::ffi::module::MODULE; +use crate::ffi::vec_znx_big::VEC_ZNX_BIG; + +#[repr(C)] +#[derive(Debug, Copy, Clone)] +pub struct vec_znx_dft_t { + _unused: [u8; 0], +} +pub type VEC_ZNX_DFT = vec_znx_dft_t; + +unsafe extern "C" { + pub unsafe fn bytes_of_vec_znx_dft(module: *const MODULE, size: u64) -> u64; +} +unsafe extern "C" { + pub unsafe fn new_vec_znx_dft(module: *const MODULE, size: u64) -> *mut VEC_ZNX_DFT; +} +unsafe extern "C" { + pub unsafe fn delete_vec_znx_dft(res: *mut VEC_ZNX_DFT); +} + +unsafe extern "C" { + pub fn vec_dft_zero(module: *const MODULE, res: *mut VEC_ZNX_DFT, res_size: u64); +} +unsafe extern "C" { + pub fn vec_dft_add( + module: *const MODULE, + res: *mut VEC_ZNX_DFT, + res_size: u64, + a: *const VEC_ZNX_DFT, + a_size: u64, + b: *const VEC_ZNX_DFT, + b_size: u64, + ); +} +unsafe extern "C" { + pub fn vec_dft_sub( + module: *const MODULE, + res: *mut VEC_ZNX_DFT, + res_size: u64, + a: *const VEC_ZNX_DFT, + a_size: u64, + b: *const VEC_ZNX_DFT, + b_size: u64, + ); +} +unsafe extern "C" { + pub fn vec_znx_dft( + module: *const MODULE, + res: *mut VEC_ZNX_DFT, + res_size: u64, + a: *const i64, + a_size: u64, + a_sl: u64, + ); +} +unsafe extern "C" { + pub fn vec_znx_idft( + module: *const MODULE, + res: *mut VEC_ZNX_BIG, + res_size: u64, + a_dft: *const VEC_ZNX_DFT, + a_size: u64, + tmp: *mut u8, + ); +} +unsafe extern "C" { + pub fn vec_znx_idft_tmp_bytes(module: *const MODULE) -> u64; +} +unsafe extern "C" { + pub fn vec_znx_idft_tmp_a( + module: *const MODULE, + res: *mut VEC_ZNX_BIG, + res_size: u64, + a_dft: *mut VEC_ZNX_DFT, + a_size: u64, + ); +} diff --git a/base2k/src/ffi/vmp.rs b/base2k/src/ffi/vmp.rs new file mode 100644 index 0000000..e202e4c --- /dev/null +++ b/base2k/src/ffi/vmp.rs @@ -0,0 +1,96 @@ +use crate::ffi::module::MODULE; +use crate::ffi::vec_znx_dft::VEC_ZNX_DFT; + +#[repr(C)] +#[derive(Debug, Copy, Clone)] +pub struct vmp_pmat_t { + _unused: [u8; 0], +} +pub type VMP_PMAT = vmp_pmat_t; + +unsafe extern "C" { + pub unsafe fn bytes_of_vmp_pmat(module: *const MODULE, nrows: u64, ncols: u64) -> u64; +} +unsafe extern "C" { + pub unsafe fn new_vmp_pmat(module: *const MODULE, nrows: u64, ncols: u64) -> *mut VMP_PMAT; +} +unsafe extern "C" { + pub unsafe fn delete_vmp_pmat(res: *mut VMP_PMAT); +} + +unsafe extern "C" { + pub unsafe fn vmp_apply_dft( + module: *const MODULE, + res: *mut VEC_ZNX_DFT, + res_size: u64, + a: *const i64, + a_size: u64, + a_sl: u64, + pmat: *const VMP_PMAT, + nrows: u64, + ncols: u64, + tmp_space: *mut u8, + ); +} + +unsafe extern "C" { + pub unsafe fn vmp_apply_dft_tmp_bytes( + module: *const MODULE, + res_size: u64, + a_size: u64, + nrows: u64, + ncols: u64, + ) -> u64; +} + +unsafe extern "C" { + pub unsafe fn vmp_apply_dft_to_dft( + module: *const MODULE, + res: *mut VEC_ZNX_DFT, + res_size: u64, + a_dft: *const VEC_ZNX_DFT, + a_size: u64, + pmat: *const VMP_PMAT, + nrows: u64, + ncols: u64, + tmp_space: *mut u8, + ); +} + +unsafe extern "C" { + pub unsafe fn vmp_apply_dft_to_dft_tmp_bytes( + module: *const MODULE, + res_size: u64, + a_size: u64, + nrows: u64, + ncols: u64, + ) -> u64; +} + +unsafe extern "C" { + pub fn vmp_prepare_contiguous( + module: *const MODULE, + pmat: *mut VMP_PMAT, + mat: *const i64, + nrows: u64, + ncols: u64, + tmp_space: *mut u8, + ); +} +unsafe extern "C" { + pub fn vmp_prepare_dblptr( + module: *const MODULE, + pmat: *mut VMP_PMAT, + mat: *mut *const i64, + nrows: u64, + ncols: u64, + tmp_space: *mut u8, + ); +} +unsafe extern "C" { + pub unsafe fn vmp_prepare_contiguous_tmp_bytes( + module: *const MODULE, + nrows: u64, + ncols: u64, + ) -> u64; +} diff --git a/base2k/src/ffi/znx.rs b/base2k/src/ffi/znx.rs new file mode 100644 index 0000000..24674e8 --- /dev/null +++ b/base2k/src/ffi/znx.rs @@ -0,0 +1,89 @@ +use crate::ffi::module::MODULE; + +unsafe extern "C" { + pub unsafe fn znx_add_i64_ref(nn: u64, res: *mut i64, a: *const i64, b: *const i64); +} +unsafe extern "C" { + pub unsafe fn znx_add_i64_avx(nn: u64, res: *mut i64, a: *const i64, b: *const i64); +} +unsafe extern "C" { + pub unsafe fn znx_sub_i64_ref(nn: u64, res: *mut i64, a: *const i64, b: *const i64); +} +unsafe extern "C" { + pub unsafe fn znx_sub_i64_avx(nn: u64, res: *mut i64, a: *const i64, b: *const i64); +} +unsafe extern "C" { + pub unsafe fn znx_negate_i64_ref(nn: u64, res: *mut i64, a: *const i64); +} +unsafe extern "C" { + pub unsafe fn znx_negate_i64_avx(nn: u64, res: *mut i64, a: *const i64); +} +unsafe extern "C" { + pub unsafe fn znx_copy_i64_ref(nn: u64, res: *mut i64, a: *const i64); +} +unsafe extern "C" { + pub unsafe fn znx_zero_i64_ref(nn: u64, res: *mut i64); +} +unsafe extern "C" { + pub unsafe fn rnx_divide_by_m_ref(nn: u64, m: f64, res: *mut f64, a: *const f64); +} +unsafe extern "C" { + pub unsafe fn rnx_divide_by_m_avx(nn: u64, m: f64, res: *mut f64, a: *const f64); +} +unsafe extern "C" { + pub unsafe fn rnx_rotate_f64(nn: u64, p: i64, res: *mut f64, in_: *const f64); +} +unsafe extern "C" { + pub unsafe fn znx_rotate_i64(nn: u64, p: i64, res: *mut i64, in_: *const i64); +} +unsafe extern "C" { + pub unsafe fn rnx_rotate_inplace_f64(nn: u64, p: i64, res: *mut f64); +} +unsafe extern "C" { + pub unsafe fn znx_rotate_inplace_i64(nn: u64, p: i64, res: *mut i64); +} +unsafe extern "C" { + pub unsafe fn rnx_automorphism_f64(nn: u64, p: i64, res: *mut f64, in_: *const f64); +} +unsafe extern "C" { + pub unsafe fn znx_automorphism_i64(nn: u64, p: i64, res: *mut i64, in_: *const i64); +} +unsafe extern "C" { + pub unsafe fn rnx_automorphism_inplace_f64(nn: u64, p: i64, res: *mut f64); +} +unsafe extern "C" { + pub unsafe fn znx_automorphism_inplace_i64(nn: u64, p: i64, res: *mut i64); +} +unsafe extern "C" { + pub unsafe fn rnx_mul_xp_minus_one(nn: u64, p: i64, res: *mut f64, in_: *const f64); +} +unsafe extern "C" { + pub unsafe fn znx_mul_xp_minus_one(nn: u64, p: i64, res: *mut i64, in_: *const i64); +} +unsafe extern "C" { + pub unsafe fn rnx_mul_xp_minus_one_inplace(nn: u64, p: i64, res: *mut f64); +} +unsafe extern "C" { + pub unsafe fn znx_normalize( + nn: u64, + base_k: u64, + out: *mut i64, + carry_out: *mut i64, + in_: *const i64, + carry_in: *const i64, + ); +} + +unsafe extern "C" { + pub unsafe fn znx_small_single_product( + module: *const MODULE, + res: *mut i64, + a: *const i64, + b: *const i64, + tmp: *mut u8, + ); +} + +unsafe extern "C" { + pub unsafe fn znx_small_single_product_tmp_bytes(module: *const MODULE) -> u64; +} diff --git a/base2k/src/lib.rs b/base2k/src/lib.rs index 02129e6..91c6054 100644 --- a/base2k/src/lib.rs +++ b/base2k/src/lib.rs @@ -5,9 +5,7 @@ dead_code, improper_ctypes )] -pub mod bindings { - include!(concat!(env!("OUT_DIR"), "/bindings.rs")); -} +pub mod ffi; pub mod module; #[allow(unused_imports)] @@ -55,3 +53,15 @@ pub fn cast_mut_u8_to_mut_i64_slice(data: &mut [u8]) -> &mut [i64] { let len: usize = data.len() / std::mem::size_of::(); unsafe { std::slice::from_raw_parts_mut(ptr, len) } } + +pub fn cast_mut_u8_to_mut_f64_slice(data: &mut [u8]) -> &mut [f64] { + let ptr: *mut f64 = data.as_mut_ptr() as *mut f64; + let len: usize = data.len() / std::mem::size_of::(); + unsafe { std::slice::from_raw_parts_mut(ptr, len) } +} + +pub fn cast_u8_to_f64_slice(data: &mut [u8]) -> &[f64] { + let ptr: *const f64 = data.as_mut_ptr() as *const f64; + let len: usize = data.len() / std::mem::size_of::(); + unsafe { std::slice::from_raw_parts(ptr, len) } +} diff --git a/base2k/src/module.rs b/base2k/src/module.rs index cfbdd83..2ce179a 100644 --- a/base2k/src/module.rs +++ b/base2k/src/module.rs @@ -1,7 +1,4 @@ -use crate::bindings::{ - module_info_t, new_module_info, svp_ppol_t, vec_znx_bigcoeff_t, vec_znx_dft_t, MODULE, -}; - +use crate::ffi::module::{delete_module_info, module_info_t, new_module_info, MODULE}; use crate::GALOISGENERATOR; pub type MODULETYPE = u8; @@ -56,28 +53,9 @@ impl Module { (gal_el as i64) * gen.signum() } -} -pub struct SvpPPol(pub *mut svp_ppol_t); - -pub struct VecZnxBig(pub *mut vec_znx_bigcoeff_t, pub usize); - -impl VecZnxBig { - pub fn as_vec_znx_dft(&mut self) -> VecZnxDft { - VecZnxDft(self.0 as *mut vec_znx_dft_t, self.1) - } - pub fn limbs(&self) -> usize { - self.1 - } -} - -pub struct VecZnxDft(pub *mut vec_znx_dft_t, pub usize); - -impl VecZnxDft { - pub fn as_vec_znx_big(&mut self) -> VecZnxBig { - VecZnxBig(self.0 as *mut vec_znx_bigcoeff_t, self.1) - } - pub fn limbs(&self) -> usize { - self.1 + pub fn delete(self) { + unsafe { delete_module_info(self.0) } + drop(self); } } diff --git a/base2k/src/scalar_vector_product.rs b/base2k/src/scalar_vector_product.rs index 83d157a..548ad04 100644 --- a/base2k/src/scalar_vector_product.rs +++ b/base2k/src/scalar_vector_product.rs @@ -1,6 +1,15 @@ -use crate::bindings::{new_svp_ppol, svp_apply_dft, svp_prepare}; +use crate::ffi::svp::{delete_svp_ppol, new_svp_ppol, svp_apply_dft, svp_ppol_t, svp_prepare}; use crate::scalar::Scalar; -use crate::{Module, SvpPPol, VecZnx, VecZnxDft}; +use crate::{Module, VecZnx, VecZnxDft}; + +pub struct SvpPPol(pub *mut svp_ppol_t); + +impl SvpPPol { + pub fn delete(self) { + unsafe { delete_svp_ppol(self.0) }; + let _ = drop(self); + } +} impl Module { // Prepares a scalar polynomial (1 limb) for a scalar x vector product. diff --git a/base2k/src/vec_znx.rs b/base2k/src/vec_znx.rs index d2413b0..7d4c9a8 100644 --- a/base2k/src/vec_znx.rs +++ b/base2k/src/vec_znx.rs @@ -1,7 +1,7 @@ -use crate::bindings::{ +use crate::cast_mut_u8_to_mut_i64_slice; +use crate::ffi::znx::{ znx_automorphism_i64, znx_automorphism_inplace_i64, znx_normalize, znx_zero_i64_ref, }; -use crate::cast_mut_u8_to_mut_i64_slice; use crate::module::Module; use itertools::izip; use rand_distr::{Distribution, Normal}; @@ -98,7 +98,7 @@ impl VecZnx { &mut self.data[i * self.n..(i + 1) * self.n] } - pub fn set_i64(&mut self, data: &[i64], log_max: usize) { + pub fn from_i64(&mut self, data: &[i64], log_max: usize) { let size: usize = min(data.len(), self.n()); let k_rem: usize = self.log_base2k - (self.log_q % self.log_base2k); @@ -131,7 +131,7 @@ impl VecZnx { } } - pub fn set_single_i64(&mut self, i: usize, value: i64, log_max: usize) { + pub fn from_i64_single(&mut self, i: usize, value: i64, log_max: usize) { assert!(i < self.n()); let k_rem: usize = self.log_base2k - (self.log_q % self.log_base2k); @@ -187,7 +187,7 @@ impl VecZnx { } } - pub fn get_i64(&self, data: &mut [i64]) { + pub fn to_i64(&self, data: &mut [i64]) { assert!( data.len() >= self.n, "invalid data: data.len()={} < self.n()={}", @@ -210,7 +210,7 @@ impl VecZnx { }) } - pub fn get_single_i64(&self, i: usize) -> i64 { + pub fn to_i64_single(&self, i: usize) -> i64 { assert!(i < self.n()); let mut res: i64 = self.data[i]; let rem: usize = self.log_base2k - (self.log_q % self.log_base2k); @@ -366,9 +366,9 @@ mod tests { have.iter_mut() .enumerate() .for_each(|(i, x)| *x = (i as i64) - (n as i64) / 2); - a.set_i64(&have, 10); + a.from_i64(&have, 10); let mut want = vec![i64::default(); n]; - a.get_i64(&mut want); + a.to_i64(&mut want); izip!(want, have).for_each(|(a, b)| assert_eq!(a, b)); } @@ -385,11 +385,11 @@ mod tests { .next_u64n(u64::MAX, u64::MAX) .wrapping_sub(u64::MAX / 2 + 1) as i64; }); - a.set_i64(&have, 63); + a.from_i64(&have, 63); //(0..a.limbs()).for_each(|i| println!("i:{} -> {:?}", i, a.at(i))); let mut want = vec![i64::default(); n]; //(0..a.limbs()).for_each(|i| println!("i:{} -> {:?}", i, a.at(i))); - a.get_i64(&mut want); + a.to_i64(&mut want); izip!(want, have).for_each(|(a, b)| assert_eq!(a, b, "{} != {}", a, b)); } #[test] @@ -405,7 +405,7 @@ mod tests { .next_u64n(u64::MAX, u64::MAX) .wrapping_sub(u64::MAX / 2 + 1) as i64; }); - a.set_i64(&have, 63); + a.from_i64(&have, 63); let mut carry: Vec = vec![u8::default(); n * 8]; a.normalize(&mut carry); @@ -414,7 +414,7 @@ mod tests { .iter() .for_each(|x| assert!(x.abs() <= base_half, "|x|={} > 2^(k-1)={}", x, base_half)); let mut want = vec![i64::default(); n]; - a.get_i64(&mut want); + a.to_i64(&mut want); izip!(want, have).for_each(|(a, b)| assert_eq!(a, b, "{} != {}", a, b)); } } diff --git a/base2k/src/vec_znx_arithmetic.rs b/base2k/src/vec_znx_arithmetic.rs index 203a37f..5e8bb08 100644 --- a/base2k/src/vec_znx_arithmetic.rs +++ b/base2k/src/vec_znx_arithmetic.rs @@ -1,4 +1,4 @@ -use crate::bindings::{ +use crate::ffi::vec_znx::{ vec_znx_add, vec_znx_automorphism, vec_znx_negate, vec_znx_rotate, vec_znx_sub, }; use crate::{Module, VecZnx}; diff --git a/base2k/src/vec_znx_big_arithmetic.rs b/base2k/src/vec_znx_big_arithmetic.rs index dcca5e3..84d0396 100644 --- a/base2k/src/vec_znx_big_arithmetic.rs +++ b/base2k/src/vec_znx_big_arithmetic.rs @@ -1,8 +1,28 @@ -use crate::bindings::{ - new_vec_znx_big, vec_znx_big_add_small, vec_znx_big_automorphism, vec_znx_big_normalize_base2k, - vec_znx_big_normalize_base2k_tmp_bytes, vec_znx_big_sub_small_a, +use crate::ffi::vec_znx_big::{ + delete_vec_znx_big, new_vec_znx_big, vec_znx_big_add_small, vec_znx_big_automorphism, + vec_znx_big_normalize_base2k, vec_znx_big_normalize_base2k_tmp_bytes, vec_znx_big_sub_small_a, + vec_znx_bigcoeff_t, }; -use crate::{Module, VecZnx, VecZnxBig}; +use crate::ffi::vec_znx_dft::vec_znx_dft_t; + +use crate::{Module, VecZnx, VecZnxDft}; + +pub struct VecZnxBig(pub *mut vec_znx_bigcoeff_t, pub usize); + +impl VecZnxBig { + pub fn as_vec_znx_dft(&mut self) -> VecZnxDft { + VecZnxDft(self.0 as *mut vec_znx_dft_t, self.1) + } + pub fn limbs(&self) -> usize { + self.1 + } + pub fn delete(self) { + unsafe { + delete_vec_znx_big(self.0); + } + drop(self); + } +} impl Module { // Allocates a vector Z[X]/(X^N+1) that stores not normalized values. diff --git a/base2k/src/vec_znx_dft.rs b/base2k/src/vec_znx_dft.rs index 4296300..c5f3bec 100644 --- a/base2k/src/vec_znx_dft.rs +++ b/base2k/src/vec_znx_dft.rs @@ -1,5 +1,25 @@ -use crate::bindings::{new_vec_znx_dft, vec_znx_idft, vec_znx_idft_tmp_a, vec_znx_idft_tmp_bytes}; -use crate::module::{Module, VecZnxBig, VecZnxDft}; +use crate::ffi::vec_znx_big::vec_znx_bigcoeff_t; +use crate::ffi::vec_znx_dft::{ + delete_vec_znx_dft, new_vec_znx_dft, vec_znx_dft_t, vec_znx_idft, vec_znx_idft_tmp_a, + vec_znx_idft_tmp_bytes, +}; +use crate::{Module, VecZnxBig}; + +pub struct VecZnxDft(pub *mut vec_znx_dft_t, pub usize); + +impl VecZnxDft { + pub fn as_vec_znx_big(&mut self) -> VecZnxBig { + VecZnxBig(self.0 as *mut vec_znx_bigcoeff_t, self.1) + } + pub fn limbs(&self) -> usize { + self.1 + } + + pub fn delete(self) { + unsafe { delete_vec_znx_dft(self.0) }; + drop(self); + } +} impl Module { // Allocates a vector Z[X]/(X^N+1) that stores normalized in the DFT space. diff --git a/base2k/src/vector_matrix_product.rs b/base2k/src/vector_matrix_product.rs index 9b0f89b..b97d721 100644 --- a/base2k/src/vector_matrix_product.rs +++ b/base2k/src/vector_matrix_product.rs @@ -1,25 +1,96 @@ -use crate::bindings::{ - new_vmp_pmat, vmp_apply_dft, vmp_apply_dft_tmp_bytes, vmp_apply_dft_to_dft, +use crate::ffi::vmp::{ + delete_vmp_pmat, new_vmp_pmat, vmp_apply_dft, vmp_apply_dft_tmp_bytes, vmp_apply_dft_to_dft, vmp_apply_dft_to_dft_tmp_bytes, vmp_pmat_t, vmp_prepare_contiguous, - vmp_prepare_contiguous_tmp_bytes, vmp_prepare_dblptr, + vmp_prepare_contiguous_tmp_bytes, }; use crate::{Module, VecZnx, VecZnxDft}; -pub struct VmpPMat(pub *mut vmp_pmat_t, pub usize, pub usize); +pub struct VmpPMat { + pub data: *mut vmp_pmat_t, + pub rows: usize, + pub cols: usize, + pub n: usize, +} impl VmpPMat { + pub fn data(&self) -> *mut vmp_pmat_t { + self.data + } + pub fn rows(&self) -> usize { - self.1 + self.rows } pub fn cols(&self) -> usize { - self.2 + self.cols + } + + pub fn n(&self) -> usize { + self.n + } + + pub fn as_f64(&self) -> &[f64] { + let ptr: *const f64 = self.data as *const f64; + let len: usize = (self.rows() * self.cols() * self.n() * 8) / std::mem::size_of::(); + unsafe { &std::slice::from_raw_parts(ptr, len) } + } + + pub fn get_addr(&self, row: usize, col: usize, blk: usize) -> &[f64] { + let nrows: usize = self.rows(); + let ncols: usize = self.cols(); + if col == (ncols - 1) && (ncols & 1 == 1) { + &self.as_f64()[blk * nrows * ncols * 8 + col * nrows * 8 + row * 8..] + } else { + &self.as_f64()[blk * nrows * ncols * 8 + + (col / 2) * (2 * nrows) * 8 + + row * 2 * 8 + + (col % 2) * 8..] + } + } + + pub fn at(&self, row: usize, col: usize) -> Vec { + //assert!(row <= self.rows && col <= self.cols); + + let mut res: Vec = vec![f64::default(); self.n]; + + if self.n < 8 { + res.copy_from_slice( + &self.as_f64()[(row + col * self.rows()) * self.n() + ..(row + col * self.rows()) * (self.n() + 1)], + ); + } else { + (0..self.n >> 3).for_each(|blk| { + res[blk * 8..(blk + 1) * 8].copy_from_slice(&self.get_addr(row, col, blk)[..8]); + }); + } + + res + } + + pub fn at_mut(&self, row: usize, col: usize) -> &mut [f64] { + assert!(row <= self.rows && col <= self.cols); + let idx: usize = col * (self.n / 2 * self.rows) + row * (self.n >> 1); + let ptr: *mut f64 = self.data as *mut f64; + let len: usize = (self.rows() * self.cols() * self.n() * 8) / std::mem::size_of::(); + unsafe { &mut std::slice::from_raw_parts_mut(ptr, len)[idx..idx + self.n] } + } + + pub fn delete(self) { + unsafe { delete_vmp_pmat(self.data) }; + drop(self); } } impl Module { pub fn new_vmp_pmat(&self, rows: usize, cols: usize) -> VmpPMat { - unsafe { VmpPMat(new_vmp_pmat(self.0, rows as u64, cols as u64), rows, cols) } + unsafe { + VmpPMat { + data: new_vmp_pmat(self.0, rows as u64, cols as u64), + rows, + cols, + n: self.n(), + } + } } pub fn vmp_prepare_contiguous_tmp_bytes(&self, rows: usize, cols: usize) -> usize { @@ -30,10 +101,10 @@ impl Module { unsafe { vmp_prepare_contiguous( self.0, - b.0, + b.data(), a.as_ptr(), - b.1 as u64, - b.2 as u64, + b.rows() as u64, + b.cols() as u64, buf.as_mut_ptr(), ); } @@ -66,7 +137,7 @@ impl Module { a.as_ptr(), a.limbs() as u64, a.n() as u64, - b.0, + b.data(), b.rows() as u64, b.cols() as u64, buf.as_mut_ptr(), @@ -106,7 +177,7 @@ impl Module { c.limbs() as u64, a.0, a.limbs() as u64, - b.0, + b.data(), b.rows() as u64, b.cols() as u64, buf.as_mut_ptr(), @@ -122,7 +193,7 @@ impl Module { b.limbs() as u64, b.0, b.limbs() as u64, - a.0, + a.data(), a.rows() as u64, a.cols() as u64, buf.as_mut_ptr(),