rewrote all bindings, removed dependency on binding generation

This commit is contained in:
Jean-Philippe Bossuat
2025-01-30 17:34:57 +01:00
parent a7af4d6d1f
commit d3a8d20647
25 changed files with 1040 additions and 189 deletions

View File

@@ -12,9 +12,6 @@ rand_core = {workspace = true}
sampling = { path = "../sampling" }
utils = { path = "../utils" }
[build-dependencies]
bindgen ="0.71.1"
[[bench]]
name = "fft"
harness = false

View File

@@ -1,8 +1,4 @@
use base2k::bindings::{
new_reim_fft_precomp, new_reim_ifft_precomp, reim_fft, reim_fft_precomp,
reim_fft_precomp_get_buffer, reim_from_znx64_simple, reim_ifft, reim_ifft_precomp,
reim_ifft_precomp_get_buffer,
};
use base2k::ffi::reim::*;
use criterion::{criterion_group, criterion_main, BenchmarkId, Criterion};
use std::ffi::c_void;

View File

@@ -1,11 +1,16 @@
use bindgen;
use std::env;
use std::fs;
//use bindgen;
//use std::env;
//use std::fs;
use std::path::absolute;
use std::path::PathBuf;
use std::time::SystemTime;
//use std::path::PathBuf;
//use std::time::SystemTime;
fn main() {
/*
[build-dependencies]
bindgen ="0.71.1"
// Path to the C header file
let header_paths: [&str; 2] = [
"spqlios-arithmetic/spqlios/coeffs/coeffs_arithmetic.h",
@@ -43,6 +48,7 @@ fn main() {
.write_to_file(&bindings_file)
.expect("Couldn't write bindings!");
}
*/
println!(
"cargo:rustc-link-search=native={}",

View File

@@ -1,8 +1,4 @@
use base2k::bindings::{
new_reim_fft_precomp, new_reim_ifft_precomp, reim_fft, reim_fft_precomp_get_buffer,
reim_fftvec_mul_simple, reim_from_znx64_simple, reim_ifft, reim_ifft_precomp_get_buffer,
reim_to_znx64_simple,
};
use base2k::ffi::reim::*;
use std::ffi::c_void;
use std::time::Instant;

View File

@@ -48,7 +48,7 @@ fn main() {
.for_each(|x| *x = source.next_u64n(16, 15) as i64);
// m
m.set_i64(&want, 4);
m.from_i64(&want, 4);
m.normalize(&mut carry);
// buf_big <- m - buf_big
@@ -73,7 +73,7 @@ fn main() {
// have = m * 2^{log_scale} + e
let mut have: Vec<i64> = vec![i64::default(); n];
res.get_i64(&mut have);
res.to_i64(&mut have);
let scale: f64 = (1 << log_scale) as f64;
izip!(want.iter(), have.iter())

View File

@@ -1,7 +1,7 @@
use base2k::{Matrix3D, Module, VecZnx, VecZnxBig, VecZnxDft, VmpPMat, FFT64};
fn main() {
let log_n = 4;
let log_n = 5;
let n = 1 << log_n;
let module: Module = Module::new::<FFT64>(n);
@@ -14,7 +14,7 @@ fn main() {
// Maximum size of the byte scratch needed
let tmp_bytes: usize = module.vmp_prepare_contiguous_tmp_bytes(rows, cols)
| module.vmp_apply_dft_to_dft_tmp_bytes(limbs, limbs, rows, cols);
| module.vmp_apply_dft_tmp_bytes(limbs, limbs, rows, cols);
let mut buf: Vec<u8> = vec![0; tmp_bytes];
@@ -22,7 +22,7 @@ fn main() {
a_values[1] = (1 << log_base2k) + 1;
let mut a: VecZnx = module.new_vec_znx(log_base2k, log_q);
a.set_i64(&a_values, 32);
a.from_i64(&a_values, 32);
a.normalize(&mut buf);
(0..a.limbs()).for_each(|i| println!("{}: {:?}", i, a.at(i)));
@@ -30,9 +30,14 @@ fn main() {
let mut b_mat: Matrix3D<i64> = Matrix3D::new(rows, cols, n);
(0..rows).for_each(|i| {
b_mat.at_mut(i, i)[0] = ((1 << 15) + 1) << 15;
(0..cols).for_each(|j| {
b_mat.at_mut(i, j)[0] = (i * cols + j) as i64;
b_mat.at_mut(i, j)[0] = (i * cols + j) as i64;
})
});
//b_mat.data.iter_mut().enumerate().for_each(|(i, xi)| *xi = i as i64);
println!();
(0..rows).for_each(|i| {
(0..cols).for_each(|j| println!("{} {}: {:?}", i, j, b_mat.at(i, j)));
@@ -42,6 +47,13 @@ fn main() {
let mut vmp_pmat: VmpPMat = module.new_vmp_pmat(rows, cols);
module.vmp_prepare_contiguous(&mut vmp_pmat, &b_mat.data, &mut buf);
(0..cols).for_each(|i| {
(0..rows).for_each(|j| println!("{} {}: {:?}", i, j, vmp_pmat.at(i, j)));
println!();
});
println!("{:?}", vmp_pmat.as_f64());
let mut c_dft: VecZnxDft = module.new_vec_znx_dft(limbs);
module.vmp_apply_dft(&mut c_dft, &a, &vmp_pmat, &mut buf);
@@ -52,9 +64,13 @@ fn main() {
module.vec_znx_big_normalize(&mut res, &c_big, &mut buf);
let mut values_res: Vec<i64> = vec![i64::default(); n];
res.get_i64(&mut values_res);
res.to_i64(&mut values_res);
(0..res.limbs()).for_each(|i| println!("{}: {:?}", i, res.at(i)));
module.delete();
c_dft.delete();
vmp_pmat.delete();
println!("{:?}", values_res)
}

7
base2k/src/ffi/cnv.rs Normal file
View File

@@ -0,0 +1,7 @@
pub type CNV_PVEC_L = cnv_pvec_l_t;
#[repr(C)]
#[derive(Debug, Copy, Clone)]
pub struct cnv_pvec_r_t {
_unused: [u8; 0],
}
pub type CNV_PVEC_R = cnv_pvec_r_t;

8
base2k/src/ffi/mod.rs Normal file
View File

@@ -0,0 +1,8 @@
pub mod module;
pub mod reim;
pub mod svp;
pub mod vec_znx;
pub mod vec_znx_big;
pub mod vec_znx_dft;
pub mod vmp;
pub mod znx;

20
base2k/src/ffi/module.rs Normal file
View File

@@ -0,0 +1,20 @@
pub struct module_info_t {
_unused: [u8; 0],
}
pub type module_type_t = ::std::os::raw::c_uint;
pub const module_type_t_FFT64: module_type_t = 0;
pub const module_type_t_NTT120: module_type_t = 1;
pub use self::module_type_t as MODULE_TYPE;
pub type MODULE = module_info_t;
unsafe extern "C" {
pub unsafe fn new_module_info(N: u64, mode: MODULE_TYPE) -> *mut MODULE;
}
unsafe extern "C" {
pub unsafe fn delete_module_info(module_info: *mut MODULE);
}
unsafe extern "C" {
pub unsafe fn module_get_n(module: *const MODULE) -> u64;
}

242
base2k/src/ffi/reim.rs Normal file
View File

@@ -0,0 +1,242 @@
#[repr(C)]
#[derive(Debug, Copy, Clone)]
pub struct reim_fft_precomp {
_unused: [u8; 0],
}
pub type REIM_FFT_PRECOMP = reim_fft_precomp;
#[repr(C)]
#[derive(Debug, Copy, Clone)]
pub struct reim_ifft_precomp {
_unused: [u8; 0],
}
pub type REIM_IFFT_PRECOMP = reim_ifft_precomp;
#[repr(C)]
#[derive(Debug, Copy, Clone)]
pub struct reim_mul_precomp {
_unused: [u8; 0],
}
pub type REIM_FFTVEC_MUL_PRECOMP = reim_mul_precomp;
#[repr(C)]
#[derive(Debug, Copy, Clone)]
pub struct reim_addmul_precomp {
_unused: [u8; 0],
}
pub type REIM_FFTVEC_ADDMUL_PRECOMP = reim_addmul_precomp;
#[repr(C)]
#[derive(Debug, Copy, Clone)]
pub struct reim_from_znx32_precomp {
_unused: [u8; 0],
}
pub type REIM_FROM_ZNX32_PRECOMP = reim_from_znx32_precomp;
#[repr(C)]
#[derive(Debug, Copy, Clone)]
pub struct reim_from_znx64_precomp {
_unused: [u8; 0],
}
pub type REIM_FROM_ZNX64_PRECOMP = reim_from_znx64_precomp;
#[repr(C)]
#[derive(Debug, Copy, Clone)]
pub struct reim_from_tnx32_precomp {
_unused: [u8; 0],
}
pub type REIM_FROM_TNX32_PRECOMP = reim_from_tnx32_precomp;
#[repr(C)]
#[derive(Debug, Copy, Clone)]
pub struct reim_to_tnx32_precomp {
_unused: [u8; 0],
}
pub type REIM_TO_TNX32_PRECOMP = reim_to_tnx32_precomp;
#[repr(C)]
#[derive(Debug, Copy, Clone)]
pub struct reim_to_tnx_precomp {
_unused: [u8; 0],
}
pub type REIM_TO_TNX_PRECOMP = reim_to_tnx_precomp;
#[repr(C)]
#[derive(Debug, Copy, Clone)]
pub struct reim_to_znx64_precomp {
_unused: [u8; 0],
}
pub type REIM_TO_ZNX64_PRECOMP = reim_to_znx64_precomp;
unsafe extern "C" {
pub fn new_reim_fft_precomp(m: u32, num_buffers: u32) -> *mut REIM_FFT_PRECOMP;
}
unsafe extern "C" {
pub fn reim_fft_precomp_get_buffer(
tables: *const REIM_FFT_PRECOMP,
buffer_index: u32,
) -> *mut f64;
}
unsafe extern "C" {
pub fn new_reim_fft_buffer(m: u32) -> *mut f64;
}
unsafe extern "C" {
pub fn delete_reim_fft_buffer(buffer: *mut f64);
}
unsafe extern "C" {
pub fn reim_fft(tables: *const REIM_FFT_PRECOMP, data: *mut f64);
}
unsafe extern "C" {
pub fn new_reim_ifft_precomp(m: u32, num_buffers: u32) -> *mut REIM_IFFT_PRECOMP;
}
unsafe extern "C" {
pub fn reim_ifft_precomp_get_buffer(
tables: *const REIM_IFFT_PRECOMP,
buffer_index: u32,
) -> *mut f64;
}
unsafe extern "C" {
pub fn reim_ifft(tables: *const REIM_IFFT_PRECOMP, data: *mut f64);
}
unsafe extern "C" {
pub fn new_reim_fftvec_mul_precomp(m: u32) -> *mut REIM_FFTVEC_MUL_PRECOMP;
}
unsafe extern "C" {
pub fn reim_fftvec_mul(
tables: *const REIM_FFTVEC_MUL_PRECOMP,
r: *mut f64,
a: *const f64,
b: *const f64,
);
}
unsafe extern "C" {
pub fn new_reim_fftvec_addmul_precomp(m: u32) -> *mut REIM_FFTVEC_ADDMUL_PRECOMP;
}
unsafe extern "C" {
pub fn reim_fftvec_addmul(
tables: *const REIM_FFTVEC_ADDMUL_PRECOMP,
r: *mut f64,
a: *const f64,
b: *const f64,
);
}
unsafe extern "C" {
pub fn new_reim_from_znx32_precomp(m: u32, log2bound: u32) -> *mut REIM_FROM_ZNX32_PRECOMP;
}
unsafe extern "C" {
pub fn reim_from_znx32(
tables: *const REIM_FROM_ZNX32_PRECOMP,
r: *mut ::std::os::raw::c_void,
a: *const i32,
);
}
unsafe extern "C" {
pub fn reim_from_znx64(
tables: *const REIM_FROM_ZNX64_PRECOMP,
r: *mut ::std::os::raw::c_void,
a: *const i64,
);
}
unsafe extern "C" {
pub fn new_reim_from_znx64_precomp(m: u32, maxbnd: u32) -> *mut REIM_FROM_ZNX64_PRECOMP;
}
unsafe extern "C" {
pub fn reim_from_znx64_simple(
m: u32,
log2bound: u32,
r: *mut ::std::os::raw::c_void,
a: *const i64,
);
}
unsafe extern "C" {
pub fn new_reim_from_tnx32_precomp(m: u32) -> *mut REIM_FROM_TNX32_PRECOMP;
}
unsafe extern "C" {
pub fn reim_from_tnx32(
tables: *const REIM_FROM_TNX32_PRECOMP,
r: *mut ::std::os::raw::c_void,
a: *const i32,
);
}
unsafe extern "C" {
pub fn new_reim_to_tnx32_precomp(
m: u32,
divisor: f64,
log2overhead: u32,
) -> *mut REIM_TO_TNX32_PRECOMP;
}
unsafe extern "C" {
pub fn reim_to_tnx32(
tables: *const REIM_TO_TNX32_PRECOMP,
r: *mut i32,
a: *const ::std::os::raw::c_void,
);
}
unsafe extern "C" {
pub fn new_reim_to_tnx_precomp(
m: u32,
divisor: f64,
log2overhead: u32,
) -> *mut REIM_TO_TNX_PRECOMP;
}
unsafe extern "C" {
pub fn reim_to_tnx(tables: *const REIM_TO_TNX_PRECOMP, r: *mut f64, a: *const f64);
}
unsafe extern "C" {
pub fn reim_to_tnx_simple(m: u32, divisor: f64, log2overhead: u32, r: *mut f64, a: *const f64);
}
unsafe extern "C" {
pub fn new_reim_to_znx64_precomp(
m: u32,
divisor: f64,
log2bound: u32,
) -> *mut REIM_TO_ZNX64_PRECOMP;
}
unsafe extern "C" {
pub fn reim_to_znx64(
precomp: *const REIM_TO_ZNX64_PRECOMP,
r: *mut i64,
a: *const ::std::os::raw::c_void,
);
}
unsafe extern "C" {
pub fn reim_to_znx64_simple(
m: u32,
divisor: f64,
log2bound: u32,
r: *mut i64,
a: *const ::std::os::raw::c_void,
);
}
unsafe extern "C" {
pub fn reim_fft_simple(m: u32, data: *mut ::std::os::raw::c_void);
}
unsafe extern "C" {
pub fn reim_ifft_simple(m: u32, data: *mut ::std::os::raw::c_void);
}
unsafe extern "C" {
pub fn reim_fftvec_mul_simple(
m: u32,
r: *mut ::std::os::raw::c_void,
a: *const ::std::os::raw::c_void,
b: *const ::std::os::raw::c_void,
);
}
unsafe extern "C" {
pub fn reim_fftvec_addmul_simple(
m: u32,
r: *mut ::std::os::raw::c_void,
a: *const ::std::os::raw::c_void,
b: *const ::std::os::raw::c_void,
);
}
unsafe extern "C" {
pub fn reim_from_znx32_simple(
m: u32,
log2bound: u32,
r: *mut ::std::os::raw::c_void,
x: *const i32,
);
}
unsafe extern "C" {
pub fn reim_from_tnx32_simple(m: u32, r: *mut ::std::os::raw::c_void, x: *const i32);
}
unsafe extern "C" {
pub fn reim_to_tnx32_simple(
m: u32,
divisor: f64,
log2overhead: u32,
r: *mut i32,
x: *const ::std::os::raw::c_void,
);
}

35
base2k/src/ffi/svp.rs Normal file
View File

@@ -0,0 +1,35 @@
use crate::ffi::module::MODULE;
use crate::ffi::vec_znx_dft::VEC_ZNX_DFT;
#[repr(C)]
#[derive(Debug, Copy, Clone)]
pub struct svp_ppol_t {
_unused: [u8; 0],
}
pub type SVP_PPOL = svp_ppol_t;
unsafe extern "C" {
pub unsafe fn bytes_of_svp_ppol(module: *const MODULE) -> u64;
}
unsafe extern "C" {
pub unsafe fn new_svp_ppol(module: *const MODULE) -> *mut SVP_PPOL;
}
unsafe extern "C" {
pub unsafe fn delete_svp_ppol(res: *mut SVP_PPOL);
}
unsafe extern "C" {
pub unsafe fn svp_prepare(module: *const MODULE, ppol: *mut SVP_PPOL, pol: *const i64);
}
unsafe extern "C" {
pub unsafe fn svp_apply_dft(
module: *const MODULE,
res: *const VEC_ZNX_DFT,
res_size: u64,
ppol: *const SVP_PPOL,
a: *const i64,
a_size: u64,
a_sl: u64,
);
}

101
base2k/src/ffi/vec_znx.rs Normal file
View File

@@ -0,0 +1,101 @@
use crate::ffi::module::MODULE;
unsafe extern "C" {
pub unsafe fn vec_znx_add(
module: *const MODULE,
res: *mut i64,
res_size: u64,
res_sl: u64,
a: *const i64,
a_size: u64,
a_sl: u64,
b: *const i64,
b_size: u64,
b_sl: u64,
);
}
unsafe extern "C" {
pub unsafe fn vec_znx_automorphism(
module: *const MODULE,
p: i64,
res: *mut i64,
res_size: u64,
res_sl: u64,
a: *const i64,
a_size: u64,
a_sl: u64,
);
}
unsafe extern "C" {
pub unsafe fn vec_znx_negate(
module: *const MODULE,
res: *mut i64,
res_size: u64,
res_sl: u64,
a: *const i64,
a_size: u64,
a_sl: u64,
);
}
unsafe extern "C" {
pub unsafe fn vec_znx_rotate(
module: *const MODULE,
p: i64,
res: *mut i64,
res_size: u64,
res_sl: u64,
a: *const i64,
a_size: u64,
a_sl: u64,
);
}
unsafe extern "C" {
pub unsafe fn vec_znx_sub(
module: *const MODULE,
res: *mut i64,
res_size: u64,
res_sl: u64,
a: *const i64,
a_size: u64,
a_sl: u64,
b: *const i64,
b_size: u64,
b_sl: u64,
);
}
unsafe extern "C" {
pub fn vec_znx_zero(module: *const MODULE, res: *mut i64, res_size: u64, res_sl: u64);
}
unsafe extern "C" {
pub fn vec_znx_copy(
module: *const MODULE,
res: *mut i64,
res_size: u64,
res_sl: u64,
a: *const i64,
a_size: u64,
a_sl: u64,
);
}
unsafe extern "C" {
pub fn vec_znx_normalize_base2k(
module: *const MODULE,
log2_base2k: u64,
res: *mut i64,
res_size: u64,
res_sl: u64,
a: *const i64,
a_size: u64,
a_sl: u64,
tmp_space: *mut u8,
);
}
unsafe extern "C" {
pub fn vec_znx_normalize_base2k_tmp_bytes(module: *const MODULE) -> u64;
}

View File

@@ -0,0 +1,158 @@
use crate::ffi::module::MODULE;
#[repr(C)]
#[derive(Debug, Copy, Clone)]
pub struct vec_znx_bigcoeff_t {
_unused: [u8; 0],
}
pub type VEC_ZNX_BIG = vec_znx_bigcoeff_t;
unsafe extern "C" {
pub fn bytes_of_vec_znx_big(module: *const MODULE, size: u64) -> u64;
}
unsafe extern "C" {
pub fn new_vec_znx_big(module: *const MODULE, size: u64) -> *mut VEC_ZNX_BIG;
}
unsafe extern "C" {
pub fn delete_vec_znx_big(res: *mut VEC_ZNX_BIG);
}
unsafe extern "C" {
pub fn vec_znx_big_add(
module: *const MODULE,
res: *mut VEC_ZNX_BIG,
res_size: u64,
a: *const VEC_ZNX_BIG,
a_size: u64,
b: *const VEC_ZNX_BIG,
b_size: u64,
);
}
unsafe extern "C" {
pub fn vec_znx_big_add_small(
module: *const MODULE,
res: *mut VEC_ZNX_BIG,
res_size: u64,
a: *const VEC_ZNX_BIG,
a_size: u64,
b: *const i64,
b_size: u64,
b_sl: u64,
);
}
unsafe extern "C" {
pub fn vec_znx_big_add_small2(
module: *const MODULE,
res: *mut VEC_ZNX_BIG,
res_size: u64,
a: *const i64,
a_size: u64,
a_sl: u64,
b: *const i64,
b_size: u64,
b_sl: u64,
);
}
unsafe extern "C" {
pub fn vec_znx_big_sub(
module: *const MODULE,
res: *mut VEC_ZNX_BIG,
res_size: u64,
a: *const VEC_ZNX_BIG,
a_size: u64,
b: *const VEC_ZNX_BIG,
b_size: u64,
);
}
unsafe extern "C" {
pub fn vec_znx_big_sub_small_b(
module: *const MODULE,
res: *mut VEC_ZNX_BIG,
res_size: u64,
a: *const VEC_ZNX_BIG,
a_size: u64,
b: *const i64,
b_size: u64,
b_sl: u64,
);
}
unsafe extern "C" {
pub fn vec_znx_big_sub_small_a(
module: *const MODULE,
res: *mut VEC_ZNX_BIG,
res_size: u64,
a: *const i64,
a_size: u64,
a_sl: u64,
b: *const VEC_ZNX_BIG,
b_size: u64,
);
}
unsafe extern "C" {
pub fn vec_znx_big_sub_small2(
module: *const MODULE,
res: *mut VEC_ZNX_BIG,
res_size: u64,
a: *const i64,
a_size: u64,
a_sl: u64,
b: *const i64,
b_size: u64,
b_sl: u64,
);
}
unsafe extern "C" {
pub fn vec_znx_big_normalize_base2k(
module: *const MODULE,
log2_base2k: u64,
res: *mut i64,
res_size: u64,
res_sl: u64,
a: *const VEC_ZNX_BIG,
a_size: u64,
tmp_space: *mut u8,
);
}
unsafe extern "C" {
pub fn vec_znx_big_normalize_base2k_tmp_bytes(module: *const MODULE) -> u64;
}
unsafe extern "C" {
pub fn vec_znx_big_automorphism(
module: *const MODULE,
p: i64,
res: *mut VEC_ZNX_BIG,
res_size: u64,
a: *const VEC_ZNX_BIG,
a_size: u64,
);
}
unsafe extern "C" {
pub fn vec_znx_big_rotate(
module: *const MODULE,
p: i64,
res: *mut VEC_ZNX_BIG,
res_size: u64,
a: *const VEC_ZNX_BIG,
a_size: u64,
);
}
unsafe extern "C" {
pub fn vec_znx_big_range_normalize_base2k(
module: *const MODULE,
log2_base2k: u64,
res: *mut i64,
res_size: u64,
res_sl: u64,
a: *const VEC_ZNX_BIG,
a_range_begin: u64,
a_range_xend: u64,
a_range_step: u64,
tmp_space: *mut u8,
);
}
unsafe extern "C" {
pub fn vec_znx_big_range_normalize_base2k_tmp_bytes(module: *const MODULE) -> u64;
}

View File

@@ -0,0 +1,77 @@
use crate::ffi::module::MODULE;
use crate::ffi::vec_znx_big::VEC_ZNX_BIG;
#[repr(C)]
#[derive(Debug, Copy, Clone)]
pub struct vec_znx_dft_t {
_unused: [u8; 0],
}
pub type VEC_ZNX_DFT = vec_znx_dft_t;
unsafe extern "C" {
pub unsafe fn bytes_of_vec_znx_dft(module: *const MODULE, size: u64) -> u64;
}
unsafe extern "C" {
pub unsafe fn new_vec_znx_dft(module: *const MODULE, size: u64) -> *mut VEC_ZNX_DFT;
}
unsafe extern "C" {
pub unsafe fn delete_vec_znx_dft(res: *mut VEC_ZNX_DFT);
}
unsafe extern "C" {
pub fn vec_dft_zero(module: *const MODULE, res: *mut VEC_ZNX_DFT, res_size: u64);
}
unsafe extern "C" {
pub fn vec_dft_add(
module: *const MODULE,
res: *mut VEC_ZNX_DFT,
res_size: u64,
a: *const VEC_ZNX_DFT,
a_size: u64,
b: *const VEC_ZNX_DFT,
b_size: u64,
);
}
unsafe extern "C" {
pub fn vec_dft_sub(
module: *const MODULE,
res: *mut VEC_ZNX_DFT,
res_size: u64,
a: *const VEC_ZNX_DFT,
a_size: u64,
b: *const VEC_ZNX_DFT,
b_size: u64,
);
}
unsafe extern "C" {
pub fn vec_znx_dft(
module: *const MODULE,
res: *mut VEC_ZNX_DFT,
res_size: u64,
a: *const i64,
a_size: u64,
a_sl: u64,
);
}
unsafe extern "C" {
pub fn vec_znx_idft(
module: *const MODULE,
res: *mut VEC_ZNX_BIG,
res_size: u64,
a_dft: *const VEC_ZNX_DFT,
a_size: u64,
tmp: *mut u8,
);
}
unsafe extern "C" {
pub fn vec_znx_idft_tmp_bytes(module: *const MODULE) -> u64;
}
unsafe extern "C" {
pub fn vec_znx_idft_tmp_a(
module: *const MODULE,
res: *mut VEC_ZNX_BIG,
res_size: u64,
a_dft: *mut VEC_ZNX_DFT,
a_size: u64,
);
}

96
base2k/src/ffi/vmp.rs Normal file
View File

@@ -0,0 +1,96 @@
use crate::ffi::module::MODULE;
use crate::ffi::vec_znx_dft::VEC_ZNX_DFT;
#[repr(C)]
#[derive(Debug, Copy, Clone)]
pub struct vmp_pmat_t {
_unused: [u8; 0],
}
pub type VMP_PMAT = vmp_pmat_t;
unsafe extern "C" {
pub unsafe fn bytes_of_vmp_pmat(module: *const MODULE, nrows: u64, ncols: u64) -> u64;
}
unsafe extern "C" {
pub unsafe fn new_vmp_pmat(module: *const MODULE, nrows: u64, ncols: u64) -> *mut VMP_PMAT;
}
unsafe extern "C" {
pub unsafe fn delete_vmp_pmat(res: *mut VMP_PMAT);
}
unsafe extern "C" {
pub unsafe fn vmp_apply_dft(
module: *const MODULE,
res: *mut VEC_ZNX_DFT,
res_size: u64,
a: *const i64,
a_size: u64,
a_sl: u64,
pmat: *const VMP_PMAT,
nrows: u64,
ncols: u64,
tmp_space: *mut u8,
);
}
unsafe extern "C" {
pub unsafe fn vmp_apply_dft_tmp_bytes(
module: *const MODULE,
res_size: u64,
a_size: u64,
nrows: u64,
ncols: u64,
) -> u64;
}
unsafe extern "C" {
pub unsafe fn vmp_apply_dft_to_dft(
module: *const MODULE,
res: *mut VEC_ZNX_DFT,
res_size: u64,
a_dft: *const VEC_ZNX_DFT,
a_size: u64,
pmat: *const VMP_PMAT,
nrows: u64,
ncols: u64,
tmp_space: *mut u8,
);
}
unsafe extern "C" {
pub unsafe fn vmp_apply_dft_to_dft_tmp_bytes(
module: *const MODULE,
res_size: u64,
a_size: u64,
nrows: u64,
ncols: u64,
) -> u64;
}
unsafe extern "C" {
pub fn vmp_prepare_contiguous(
module: *const MODULE,
pmat: *mut VMP_PMAT,
mat: *const i64,
nrows: u64,
ncols: u64,
tmp_space: *mut u8,
);
}
unsafe extern "C" {
pub fn vmp_prepare_dblptr(
module: *const MODULE,
pmat: *mut VMP_PMAT,
mat: *mut *const i64,
nrows: u64,
ncols: u64,
tmp_space: *mut u8,
);
}
unsafe extern "C" {
pub unsafe fn vmp_prepare_contiguous_tmp_bytes(
module: *const MODULE,
nrows: u64,
ncols: u64,
) -> u64;
}

89
base2k/src/ffi/znx.rs Normal file
View File

@@ -0,0 +1,89 @@
use crate::ffi::module::MODULE;
unsafe extern "C" {
pub unsafe fn znx_add_i64_ref(nn: u64, res: *mut i64, a: *const i64, b: *const i64);
}
unsafe extern "C" {
pub unsafe fn znx_add_i64_avx(nn: u64, res: *mut i64, a: *const i64, b: *const i64);
}
unsafe extern "C" {
pub unsafe fn znx_sub_i64_ref(nn: u64, res: *mut i64, a: *const i64, b: *const i64);
}
unsafe extern "C" {
pub unsafe fn znx_sub_i64_avx(nn: u64, res: *mut i64, a: *const i64, b: *const i64);
}
unsafe extern "C" {
pub unsafe fn znx_negate_i64_ref(nn: u64, res: *mut i64, a: *const i64);
}
unsafe extern "C" {
pub unsafe fn znx_negate_i64_avx(nn: u64, res: *mut i64, a: *const i64);
}
unsafe extern "C" {
pub unsafe fn znx_copy_i64_ref(nn: u64, res: *mut i64, a: *const i64);
}
unsafe extern "C" {
pub unsafe fn znx_zero_i64_ref(nn: u64, res: *mut i64);
}
unsafe extern "C" {
pub unsafe fn rnx_divide_by_m_ref(nn: u64, m: f64, res: *mut f64, a: *const f64);
}
unsafe extern "C" {
pub unsafe fn rnx_divide_by_m_avx(nn: u64, m: f64, res: *mut f64, a: *const f64);
}
unsafe extern "C" {
pub unsafe fn rnx_rotate_f64(nn: u64, p: i64, res: *mut f64, in_: *const f64);
}
unsafe extern "C" {
pub unsafe fn znx_rotate_i64(nn: u64, p: i64, res: *mut i64, in_: *const i64);
}
unsafe extern "C" {
pub unsafe fn rnx_rotate_inplace_f64(nn: u64, p: i64, res: *mut f64);
}
unsafe extern "C" {
pub unsafe fn znx_rotate_inplace_i64(nn: u64, p: i64, res: *mut i64);
}
unsafe extern "C" {
pub unsafe fn rnx_automorphism_f64(nn: u64, p: i64, res: *mut f64, in_: *const f64);
}
unsafe extern "C" {
pub unsafe fn znx_automorphism_i64(nn: u64, p: i64, res: *mut i64, in_: *const i64);
}
unsafe extern "C" {
pub unsafe fn rnx_automorphism_inplace_f64(nn: u64, p: i64, res: *mut f64);
}
unsafe extern "C" {
pub unsafe fn znx_automorphism_inplace_i64(nn: u64, p: i64, res: *mut i64);
}
unsafe extern "C" {
pub unsafe fn rnx_mul_xp_minus_one(nn: u64, p: i64, res: *mut f64, in_: *const f64);
}
unsafe extern "C" {
pub unsafe fn znx_mul_xp_minus_one(nn: u64, p: i64, res: *mut i64, in_: *const i64);
}
unsafe extern "C" {
pub unsafe fn rnx_mul_xp_minus_one_inplace(nn: u64, p: i64, res: *mut f64);
}
unsafe extern "C" {
pub unsafe fn znx_normalize(
nn: u64,
base_k: u64,
out: *mut i64,
carry_out: *mut i64,
in_: *const i64,
carry_in: *const i64,
);
}
unsafe extern "C" {
pub unsafe fn znx_small_single_product(
module: *const MODULE,
res: *mut i64,
a: *const i64,
b: *const i64,
tmp: *mut u8,
);
}
unsafe extern "C" {
pub unsafe fn znx_small_single_product_tmp_bytes(module: *const MODULE) -> u64;
}

View File

@@ -5,9 +5,7 @@
dead_code,
improper_ctypes
)]
pub mod bindings {
include!(concat!(env!("OUT_DIR"), "/bindings.rs"));
}
pub mod ffi;
pub mod module;
#[allow(unused_imports)]
@@ -55,3 +53,15 @@ pub fn cast_mut_u8_to_mut_i64_slice(data: &mut [u8]) -> &mut [i64] {
let len: usize = data.len() / std::mem::size_of::<i64>();
unsafe { std::slice::from_raw_parts_mut(ptr, len) }
}
pub fn cast_mut_u8_to_mut_f64_slice(data: &mut [u8]) -> &mut [f64] {
let ptr: *mut f64 = data.as_mut_ptr() as *mut f64;
let len: usize = data.len() / std::mem::size_of::<f64>();
unsafe { std::slice::from_raw_parts_mut(ptr, len) }
}
pub fn cast_u8_to_f64_slice(data: &mut [u8]) -> &[f64] {
let ptr: *const f64 = data.as_mut_ptr() as *const f64;
let len: usize = data.len() / std::mem::size_of::<f64>();
unsafe { std::slice::from_raw_parts(ptr, len) }
}

View File

@@ -1,7 +1,4 @@
use crate::bindings::{
module_info_t, new_module_info, svp_ppol_t, vec_znx_bigcoeff_t, vec_znx_dft_t, MODULE,
};
use crate::ffi::module::{delete_module_info, module_info_t, new_module_info, MODULE};
use crate::GALOISGENERATOR;
pub type MODULETYPE = u8;
@@ -56,28 +53,9 @@ impl Module {
(gal_el as i64) * gen.signum()
}
}
pub struct SvpPPol(pub *mut svp_ppol_t);
pub struct VecZnxBig(pub *mut vec_znx_bigcoeff_t, pub usize);
impl VecZnxBig {
pub fn as_vec_znx_dft(&mut self) -> VecZnxDft {
VecZnxDft(self.0 as *mut vec_znx_dft_t, self.1)
}
pub fn limbs(&self) -> usize {
self.1
}
}
pub struct VecZnxDft(pub *mut vec_znx_dft_t, pub usize);
impl VecZnxDft {
pub fn as_vec_znx_big(&mut self) -> VecZnxBig {
VecZnxBig(self.0 as *mut vec_znx_bigcoeff_t, self.1)
}
pub fn limbs(&self) -> usize {
self.1
pub fn delete(self) {
unsafe { delete_module_info(self.0) }
drop(self);
}
}

View File

@@ -1,6 +1,15 @@
use crate::bindings::{new_svp_ppol, svp_apply_dft, svp_prepare};
use crate::ffi::svp::{delete_svp_ppol, new_svp_ppol, svp_apply_dft, svp_ppol_t, svp_prepare};
use crate::scalar::Scalar;
use crate::{Module, SvpPPol, VecZnx, VecZnxDft};
use crate::{Module, VecZnx, VecZnxDft};
pub struct SvpPPol(pub *mut svp_ppol_t);
impl SvpPPol {
pub fn delete(self) {
unsafe { delete_svp_ppol(self.0) };
let _ = drop(self);
}
}
impl Module {
// Prepares a scalar polynomial (1 limb) for a scalar x vector product.

View File

@@ -1,7 +1,7 @@
use crate::bindings::{
use crate::cast_mut_u8_to_mut_i64_slice;
use crate::ffi::znx::{
znx_automorphism_i64, znx_automorphism_inplace_i64, znx_normalize, znx_zero_i64_ref,
};
use crate::cast_mut_u8_to_mut_i64_slice;
use crate::module::Module;
use itertools::izip;
use rand_distr::{Distribution, Normal};
@@ -98,7 +98,7 @@ impl VecZnx {
&mut self.data[i * self.n..(i + 1) * self.n]
}
pub fn set_i64(&mut self, data: &[i64], log_max: usize) {
pub fn from_i64(&mut self, data: &[i64], log_max: usize) {
let size: usize = min(data.len(), self.n());
let k_rem: usize = self.log_base2k - (self.log_q % self.log_base2k);
@@ -131,7 +131,7 @@ impl VecZnx {
}
}
pub fn set_single_i64(&mut self, i: usize, value: i64, log_max: usize) {
pub fn from_i64_single(&mut self, i: usize, value: i64, log_max: usize) {
assert!(i < self.n());
let k_rem: usize = self.log_base2k - (self.log_q % self.log_base2k);
@@ -187,7 +187,7 @@ impl VecZnx {
}
}
pub fn get_i64(&self, data: &mut [i64]) {
pub fn to_i64(&self, data: &mut [i64]) {
assert!(
data.len() >= self.n,
"invalid data: data.len()={} < self.n()={}",
@@ -210,7 +210,7 @@ impl VecZnx {
})
}
pub fn get_single_i64(&self, i: usize) -> i64 {
pub fn to_i64_single(&self, i: usize) -> i64 {
assert!(i < self.n());
let mut res: i64 = self.data[i];
let rem: usize = self.log_base2k - (self.log_q % self.log_base2k);
@@ -366,9 +366,9 @@ mod tests {
have.iter_mut()
.enumerate()
.for_each(|(i, x)| *x = (i as i64) - (n as i64) / 2);
a.set_i64(&have, 10);
a.from_i64(&have, 10);
let mut want = vec![i64::default(); n];
a.get_i64(&mut want);
a.to_i64(&mut want);
izip!(want, have).for_each(|(a, b)| assert_eq!(a, b));
}
@@ -385,11 +385,11 @@ mod tests {
.next_u64n(u64::MAX, u64::MAX)
.wrapping_sub(u64::MAX / 2 + 1) as i64;
});
a.set_i64(&have, 63);
a.from_i64(&have, 63);
//(0..a.limbs()).for_each(|i| println!("i:{} -> {:?}", i, a.at(i)));
let mut want = vec![i64::default(); n];
//(0..a.limbs()).for_each(|i| println!("i:{} -> {:?}", i, a.at(i)));
a.get_i64(&mut want);
a.to_i64(&mut want);
izip!(want, have).for_each(|(a, b)| assert_eq!(a, b, "{} != {}", a, b));
}
#[test]
@@ -405,7 +405,7 @@ mod tests {
.next_u64n(u64::MAX, u64::MAX)
.wrapping_sub(u64::MAX / 2 + 1) as i64;
});
a.set_i64(&have, 63);
a.from_i64(&have, 63);
let mut carry: Vec<u8> = vec![u8::default(); n * 8];
a.normalize(&mut carry);
@@ -414,7 +414,7 @@ mod tests {
.iter()
.for_each(|x| assert!(x.abs() <= base_half, "|x|={} > 2^(k-1)={}", x, base_half));
let mut want = vec![i64::default(); n];
a.get_i64(&mut want);
a.to_i64(&mut want);
izip!(want, have).for_each(|(a, b)| assert_eq!(a, b, "{} != {}", a, b));
}
}

View File

@@ -1,4 +1,4 @@
use crate::bindings::{
use crate::ffi::vec_znx::{
vec_znx_add, vec_znx_automorphism, vec_znx_negate, vec_znx_rotate, vec_znx_sub,
};
use crate::{Module, VecZnx};

View File

@@ -1,8 +1,28 @@
use crate::bindings::{
new_vec_znx_big, vec_znx_big_add_small, vec_znx_big_automorphism, vec_znx_big_normalize_base2k,
vec_znx_big_normalize_base2k_tmp_bytes, vec_znx_big_sub_small_a,
use crate::ffi::vec_znx_big::{
delete_vec_znx_big, new_vec_znx_big, vec_znx_big_add_small, vec_znx_big_automorphism,
vec_znx_big_normalize_base2k, vec_znx_big_normalize_base2k_tmp_bytes, vec_znx_big_sub_small_a,
vec_znx_bigcoeff_t,
};
use crate::{Module, VecZnx, VecZnxBig};
use crate::ffi::vec_znx_dft::vec_znx_dft_t;
use crate::{Module, VecZnx, VecZnxDft};
pub struct VecZnxBig(pub *mut vec_znx_bigcoeff_t, pub usize);
impl VecZnxBig {
pub fn as_vec_znx_dft(&mut self) -> VecZnxDft {
VecZnxDft(self.0 as *mut vec_znx_dft_t, self.1)
}
pub fn limbs(&self) -> usize {
self.1
}
pub fn delete(self) {
unsafe {
delete_vec_znx_big(self.0);
}
drop(self);
}
}
impl Module {
// Allocates a vector Z[X]/(X^N+1) that stores not normalized values.

View File

@@ -1,5 +1,25 @@
use crate::bindings::{new_vec_znx_dft, vec_znx_idft, vec_znx_idft_tmp_a, vec_znx_idft_tmp_bytes};
use crate::module::{Module, VecZnxBig, VecZnxDft};
use crate::ffi::vec_znx_big::vec_znx_bigcoeff_t;
use crate::ffi::vec_znx_dft::{
delete_vec_znx_dft, new_vec_znx_dft, vec_znx_dft_t, vec_znx_idft, vec_znx_idft_tmp_a,
vec_znx_idft_tmp_bytes,
};
use crate::{Module, VecZnxBig};
pub struct VecZnxDft(pub *mut vec_znx_dft_t, pub usize);
impl VecZnxDft {
pub fn as_vec_znx_big(&mut self) -> VecZnxBig {
VecZnxBig(self.0 as *mut vec_znx_bigcoeff_t, self.1)
}
pub fn limbs(&self) -> usize {
self.1
}
pub fn delete(self) {
unsafe { delete_vec_znx_dft(self.0) };
drop(self);
}
}
impl Module {
// Allocates a vector Z[X]/(X^N+1) that stores normalized in the DFT space.

View File

@@ -1,25 +1,96 @@
use crate::bindings::{
new_vmp_pmat, vmp_apply_dft, vmp_apply_dft_tmp_bytes, vmp_apply_dft_to_dft,
use crate::ffi::vmp::{
delete_vmp_pmat, new_vmp_pmat, vmp_apply_dft, vmp_apply_dft_tmp_bytes, vmp_apply_dft_to_dft,
vmp_apply_dft_to_dft_tmp_bytes, vmp_pmat_t, vmp_prepare_contiguous,
vmp_prepare_contiguous_tmp_bytes, vmp_prepare_dblptr,
vmp_prepare_contiguous_tmp_bytes,
};
use crate::{Module, VecZnx, VecZnxDft};
pub struct VmpPMat(pub *mut vmp_pmat_t, pub usize, pub usize);
pub struct VmpPMat {
pub data: *mut vmp_pmat_t,
pub rows: usize,
pub cols: usize,
pub n: usize,
}
impl VmpPMat {
pub fn data(&self) -> *mut vmp_pmat_t {
self.data
}
pub fn rows(&self) -> usize {
self.1
self.rows
}
pub fn cols(&self) -> usize {
self.2
self.cols
}
pub fn n(&self) -> usize {
self.n
}
pub fn as_f64(&self) -> &[f64] {
let ptr: *const f64 = self.data as *const f64;
let len: usize = (self.rows() * self.cols() * self.n() * 8) / std::mem::size_of::<f64>();
unsafe { &std::slice::from_raw_parts(ptr, len) }
}
pub fn get_addr(&self, row: usize, col: usize, blk: usize) -> &[f64] {
let nrows: usize = self.rows();
let ncols: usize = self.cols();
if col == (ncols - 1) && (ncols & 1 == 1) {
&self.as_f64()[blk * nrows * ncols * 8 + col * nrows * 8 + row * 8..]
} else {
&self.as_f64()[blk * nrows * ncols * 8
+ (col / 2) * (2 * nrows) * 8
+ row * 2 * 8
+ (col % 2) * 8..]
}
}
pub fn at(&self, row: usize, col: usize) -> Vec<f64> {
//assert!(row <= self.rows && col <= self.cols);
let mut res: Vec<f64> = vec![f64::default(); self.n];
if self.n < 8 {
res.copy_from_slice(
&self.as_f64()[(row + col * self.rows()) * self.n()
..(row + col * self.rows()) * (self.n() + 1)],
);
} else {
(0..self.n >> 3).for_each(|blk| {
res[blk * 8..(blk + 1) * 8].copy_from_slice(&self.get_addr(row, col, blk)[..8]);
});
}
res
}
pub fn at_mut(&self, row: usize, col: usize) -> &mut [f64] {
assert!(row <= self.rows && col <= self.cols);
let idx: usize = col * (self.n / 2 * self.rows) + row * (self.n >> 1);
let ptr: *mut f64 = self.data as *mut f64;
let len: usize = (self.rows() * self.cols() * self.n() * 8) / std::mem::size_of::<f64>();
unsafe { &mut std::slice::from_raw_parts_mut(ptr, len)[idx..idx + self.n] }
}
pub fn delete(self) {
unsafe { delete_vmp_pmat(self.data) };
drop(self);
}
}
impl Module {
pub fn new_vmp_pmat(&self, rows: usize, cols: usize) -> VmpPMat {
unsafe { VmpPMat(new_vmp_pmat(self.0, rows as u64, cols as u64), rows, cols) }
unsafe {
VmpPMat {
data: new_vmp_pmat(self.0, rows as u64, cols as u64),
rows,
cols,
n: self.n(),
}
}
}
pub fn vmp_prepare_contiguous_tmp_bytes(&self, rows: usize, cols: usize) -> usize {
@@ -30,10 +101,10 @@ impl Module {
unsafe {
vmp_prepare_contiguous(
self.0,
b.0,
b.data(),
a.as_ptr(),
b.1 as u64,
b.2 as u64,
b.rows() as u64,
b.cols() as u64,
buf.as_mut_ptr(),
);
}
@@ -66,7 +137,7 @@ impl Module {
a.as_ptr(),
a.limbs() as u64,
a.n() as u64,
b.0,
b.data(),
b.rows() as u64,
b.cols() as u64,
buf.as_mut_ptr(),
@@ -106,7 +177,7 @@ impl Module {
c.limbs() as u64,
a.0,
a.limbs() as u64,
b.0,
b.data(),
b.rows() as u64,
b.cols() as u64,
buf.as_mut_ptr(),
@@ -122,7 +193,7 @@ impl Module {
b.limbs() as u64,
b.0,
b.limbs() as u64,
a.0,
a.data(),
a.rows() as u64,
a.cols() as u64,
buf.as_mut_ptr(),