mirror of
https://github.com/arnaucube/poulpy.git
synced 2026-02-10 13:16:44 +01:00
added Added vmp_extract_row, vmp_extract_row_dft, vmp_extract_tmp_bytes, vmp_prepare_row_dft
-
This commit is contained in:
@@ -1,36 +0,0 @@
|
||||
/*
|
||||
[build-dependencies]
|
||||
bindgen ="0.71.1"
|
||||
|
||||
//use bindgen;
|
||||
//use std::env;
|
||||
//use std::fs;
|
||||
//use std::path::PathBuf;
|
||||
//use std::time::SystemTime;
|
||||
|
||||
// Path to the C header file
|
||||
let header_paths: [&str; 2] = [
|
||||
"spqlios-arithmetic/spqlios/coeffs/coeffs_arithmetic.h",
|
||||
"spqlios-arithmetic/spqlios/arithmetic/vec_znx_arithmetic.h",
|
||||
];
|
||||
|
||||
let out_path: PathBuf = PathBuf::from(env::var("OUT_DIR").unwrap());
|
||||
let bindings_file = out_path.join("bindings.rs");
|
||||
|
||||
let mut builder: bindgen::Builder = bindgen::Builder::default();
|
||||
for header in header_paths {
|
||||
builder = builder.header(header);
|
||||
}
|
||||
|
||||
let bindings = builder
|
||||
.generate_comments(false) // Optional: includes comments in bindings
|
||||
.generate_inline_functions(true) // Optional: includes inline functions
|
||||
.generate()
|
||||
.expect("Unable to generate bindings");
|
||||
|
||||
// Write the bindings to the OUT_DIR
|
||||
bindings
|
||||
.write_to_file(&bindings_file)
|
||||
.expect("Couldn't write bindings!");
|
||||
|
||||
*/
|
||||
@@ -59,40 +59,40 @@ pub struct reim_to_znx64_precomp {
|
||||
}
|
||||
pub type REIM_TO_ZNX64_PRECOMP = reim_to_znx64_precomp;
|
||||
unsafe extern "C" {
|
||||
pub fn new_reim_fft_precomp(m: u32, num_buffers: u32) -> *mut REIM_FFT_PRECOMP;
|
||||
pub unsafe fn new_reim_fft_precomp(m: u32, num_buffers: u32) -> *mut REIM_FFT_PRECOMP;
|
||||
}
|
||||
unsafe extern "C" {
|
||||
pub fn reim_fft_precomp_get_buffer(
|
||||
pub unsafe fn reim_fft_precomp_get_buffer(
|
||||
tables: *const REIM_FFT_PRECOMP,
|
||||
buffer_index: u32,
|
||||
) -> *mut f64;
|
||||
}
|
||||
unsafe extern "C" {
|
||||
pub fn new_reim_fft_buffer(m: u32) -> *mut f64;
|
||||
pub unsafe fn new_reim_fft_buffer(m: u32) -> *mut f64;
|
||||
}
|
||||
unsafe extern "C" {
|
||||
pub fn delete_reim_fft_buffer(buffer: *mut f64);
|
||||
pub unsafe fn delete_reim_fft_buffer(buffer: *mut f64);
|
||||
}
|
||||
unsafe extern "C" {
|
||||
pub fn reim_fft(tables: *const REIM_FFT_PRECOMP, data: *mut f64);
|
||||
pub unsafe fn reim_fft(tables: *const REIM_FFT_PRECOMP, data: *mut f64);
|
||||
}
|
||||
unsafe extern "C" {
|
||||
pub fn new_reim_ifft_precomp(m: u32, num_buffers: u32) -> *mut REIM_IFFT_PRECOMP;
|
||||
pub unsafe fn new_reim_ifft_precomp(m: u32, num_buffers: u32) -> *mut REIM_IFFT_PRECOMP;
|
||||
}
|
||||
unsafe extern "C" {
|
||||
pub fn reim_ifft_precomp_get_buffer(
|
||||
pub unsafe fn reim_ifft_precomp_get_buffer(
|
||||
tables: *const REIM_IFFT_PRECOMP,
|
||||
buffer_index: u32,
|
||||
) -> *mut f64;
|
||||
}
|
||||
unsafe extern "C" {
|
||||
pub fn reim_ifft(tables: *const REIM_IFFT_PRECOMP, data: *mut f64);
|
||||
pub unsafe fn reim_ifft(tables: *const REIM_IFFT_PRECOMP, data: *mut f64);
|
||||
}
|
||||
unsafe extern "C" {
|
||||
pub fn new_reim_fftvec_mul_precomp(m: u32) -> *mut REIM_FFTVEC_MUL_PRECOMP;
|
||||
pub unsafe fn new_reim_fftvec_mul_precomp(m: u32) -> *mut REIM_FFTVEC_MUL_PRECOMP;
|
||||
}
|
||||
unsafe extern "C" {
|
||||
pub fn reim_fftvec_mul(
|
||||
pub unsafe fn reim_fftvec_mul(
|
||||
tables: *const REIM_FFTVEC_MUL_PRECOMP,
|
||||
r: *mut f64,
|
||||
a: *const f64,
|
||||
@@ -100,10 +100,10 @@ unsafe extern "C" {
|
||||
);
|
||||
}
|
||||
unsafe extern "C" {
|
||||
pub fn new_reim_fftvec_addmul_precomp(m: u32) -> *mut REIM_FFTVEC_ADDMUL_PRECOMP;
|
||||
pub unsafe fn new_reim_fftvec_addmul_precomp(m: u32) -> *mut REIM_FFTVEC_ADDMUL_PRECOMP;
|
||||
}
|
||||
unsafe extern "C" {
|
||||
pub fn reim_fftvec_addmul(
|
||||
pub unsafe fn reim_fftvec_addmul(
|
||||
tables: *const REIM_FFTVEC_ADDMUL_PRECOMP,
|
||||
r: *mut f64,
|
||||
a: *const f64,
|
||||
@@ -111,27 +111,30 @@ unsafe extern "C" {
|
||||
);
|
||||
}
|
||||
unsafe extern "C" {
|
||||
pub fn new_reim_from_znx32_precomp(m: u32, log2bound: u32) -> *mut REIM_FROM_ZNX32_PRECOMP;
|
||||
pub unsafe fn new_reim_from_znx32_precomp(
|
||||
m: u32,
|
||||
log2bound: u32,
|
||||
) -> *mut REIM_FROM_ZNX32_PRECOMP;
|
||||
}
|
||||
unsafe extern "C" {
|
||||
pub fn reim_from_znx32(
|
||||
pub unsafe fn reim_from_znx32(
|
||||
tables: *const REIM_FROM_ZNX32_PRECOMP,
|
||||
r: *mut ::std::os::raw::c_void,
|
||||
a: *const i32,
|
||||
);
|
||||
}
|
||||
unsafe extern "C" {
|
||||
pub fn reim_from_znx64(
|
||||
pub unsafe fn reim_from_znx64(
|
||||
tables: *const REIM_FROM_ZNX64_PRECOMP,
|
||||
r: *mut ::std::os::raw::c_void,
|
||||
a: *const i64,
|
||||
);
|
||||
}
|
||||
unsafe extern "C" {
|
||||
pub fn new_reim_from_znx64_precomp(m: u32, maxbnd: u32) -> *mut REIM_FROM_ZNX64_PRECOMP;
|
||||
pub unsafe fn new_reim_from_znx64_precomp(m: u32, maxbnd: u32) -> *mut REIM_FROM_ZNX64_PRECOMP;
|
||||
}
|
||||
unsafe extern "C" {
|
||||
pub fn reim_from_znx64_simple(
|
||||
pub unsafe fn reim_from_znx64_simple(
|
||||
m: u32,
|
||||
log2bound: u32,
|
||||
r: *mut ::std::os::raw::c_void,
|
||||
@@ -139,58 +142,64 @@ unsafe extern "C" {
|
||||
);
|
||||
}
|
||||
unsafe extern "C" {
|
||||
pub fn new_reim_from_tnx32_precomp(m: u32) -> *mut REIM_FROM_TNX32_PRECOMP;
|
||||
pub unsafe fn new_reim_from_tnx32_precomp(m: u32) -> *mut REIM_FROM_TNX32_PRECOMP;
|
||||
}
|
||||
unsafe extern "C" {
|
||||
pub fn reim_from_tnx32(
|
||||
pub unsafe fn reim_from_tnx32(
|
||||
tables: *const REIM_FROM_TNX32_PRECOMP,
|
||||
r: *mut ::std::os::raw::c_void,
|
||||
a: *const i32,
|
||||
);
|
||||
}
|
||||
unsafe extern "C" {
|
||||
pub fn new_reim_to_tnx32_precomp(
|
||||
pub unsafe fn new_reim_to_tnx32_precomp(
|
||||
m: u32,
|
||||
divisor: f64,
|
||||
log2overhead: u32,
|
||||
) -> *mut REIM_TO_TNX32_PRECOMP;
|
||||
}
|
||||
unsafe extern "C" {
|
||||
pub fn reim_to_tnx32(
|
||||
pub unsafe fn reim_to_tnx32(
|
||||
tables: *const REIM_TO_TNX32_PRECOMP,
|
||||
r: *mut i32,
|
||||
a: *const ::std::os::raw::c_void,
|
||||
);
|
||||
}
|
||||
unsafe extern "C" {
|
||||
pub fn new_reim_to_tnx_precomp(
|
||||
pub unsafe fn new_reim_to_tnx_precomp(
|
||||
m: u32,
|
||||
divisor: f64,
|
||||
log2overhead: u32,
|
||||
) -> *mut REIM_TO_TNX_PRECOMP;
|
||||
}
|
||||
unsafe extern "C" {
|
||||
pub fn reim_to_tnx(tables: *const REIM_TO_TNX_PRECOMP, r: *mut f64, a: *const f64);
|
||||
pub unsafe fn reim_to_tnx(tables: *const REIM_TO_TNX_PRECOMP, r: *mut f64, a: *const f64);
|
||||
}
|
||||
unsafe extern "C" {
|
||||
pub fn reim_to_tnx_simple(m: u32, divisor: f64, log2overhead: u32, r: *mut f64, a: *const f64);
|
||||
pub unsafe fn reim_to_tnx_simple(
|
||||
m: u32,
|
||||
divisor: f64,
|
||||
log2overhead: u32,
|
||||
r: *mut f64,
|
||||
a: *const f64,
|
||||
);
|
||||
}
|
||||
unsafe extern "C" {
|
||||
pub fn new_reim_to_znx64_precomp(
|
||||
pub unsafe fn new_reim_to_znx64_precomp(
|
||||
m: u32,
|
||||
divisor: f64,
|
||||
log2bound: u32,
|
||||
) -> *mut REIM_TO_ZNX64_PRECOMP;
|
||||
}
|
||||
unsafe extern "C" {
|
||||
pub fn reim_to_znx64(
|
||||
pub unsafe fn reim_to_znx64(
|
||||
precomp: *const REIM_TO_ZNX64_PRECOMP,
|
||||
r: *mut i64,
|
||||
a: *const ::std::os::raw::c_void,
|
||||
);
|
||||
}
|
||||
unsafe extern "C" {
|
||||
pub fn reim_to_znx64_simple(
|
||||
pub unsafe fn reim_to_znx64_simple(
|
||||
m: u32,
|
||||
divisor: f64,
|
||||
log2bound: u32,
|
||||
@@ -199,13 +208,13 @@ unsafe extern "C" {
|
||||
);
|
||||
}
|
||||
unsafe extern "C" {
|
||||
pub fn reim_fft_simple(m: u32, data: *mut ::std::os::raw::c_void);
|
||||
pub unsafe fn reim_fft_simple(m: u32, data: *mut ::std::os::raw::c_void);
|
||||
}
|
||||
unsafe extern "C" {
|
||||
pub fn reim_ifft_simple(m: u32, data: *mut ::std::os::raw::c_void);
|
||||
pub unsafe fn reim_ifft_simple(m: u32, data: *mut ::std::os::raw::c_void);
|
||||
}
|
||||
unsafe extern "C" {
|
||||
pub fn reim_fftvec_mul_simple(
|
||||
pub unsafe fn reim_fftvec_mul_simple(
|
||||
m: u32,
|
||||
r: *mut ::std::os::raw::c_void,
|
||||
a: *const ::std::os::raw::c_void,
|
||||
@@ -213,7 +222,7 @@ unsafe extern "C" {
|
||||
);
|
||||
}
|
||||
unsafe extern "C" {
|
||||
pub fn reim_fftvec_addmul_simple(
|
||||
pub unsafe fn reim_fftvec_addmul_simple(
|
||||
m: u32,
|
||||
r: *mut ::std::os::raw::c_void,
|
||||
a: *const ::std::os::raw::c_void,
|
||||
@@ -221,7 +230,7 @@ unsafe extern "C" {
|
||||
);
|
||||
}
|
||||
unsafe extern "C" {
|
||||
pub fn reim_from_znx32_simple(
|
||||
pub unsafe fn reim_from_znx32_simple(
|
||||
m: u32,
|
||||
log2bound: u32,
|
||||
r: *mut ::std::os::raw::c_void,
|
||||
@@ -229,10 +238,10 @@ unsafe extern "C" {
|
||||
);
|
||||
}
|
||||
unsafe extern "C" {
|
||||
pub fn reim_from_tnx32_simple(m: u32, r: *mut ::std::os::raw::c_void, x: *const i32);
|
||||
pub unsafe fn reim_from_tnx32_simple(m: u32, r: *mut ::std::os::raw::c_void, x: *const i32);
|
||||
}
|
||||
unsafe extern "C" {
|
||||
pub fn reim_to_tnx32_simple(
|
||||
pub unsafe fn reim_to_tnx32_simple(
|
||||
m: u32,
|
||||
divisor: f64,
|
||||
log2overhead: u32,
|
||||
|
||||
@@ -2,10 +2,10 @@ use crate::ffi::module::MODULE;
|
||||
|
||||
#[repr(C)]
|
||||
#[derive(Debug, Copy, Clone)]
|
||||
pub struct vec_znx_bigcoeff_t {
|
||||
pub struct vec_znx_big_t {
|
||||
_unused: [u8; 0],
|
||||
}
|
||||
pub type VEC_ZNX_BIG = vec_znx_bigcoeff_t;
|
||||
pub type VEC_ZNX_BIG = vec_znx_big_t;
|
||||
|
||||
unsafe extern "C" {
|
||||
pub fn bytes_of_vec_znx_big(module: *const MODULE, size: u64) -> u64;
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
use crate::ffi::module::MODULE;
|
||||
use crate::ffi::vec_znx_big::VEC_ZNX_BIG;
|
||||
use crate::ffi::vec_znx_dft::VEC_ZNX_DFT;
|
||||
|
||||
#[repr(C)]
|
||||
@@ -103,6 +104,39 @@ unsafe extern "C" {
|
||||
);
|
||||
}
|
||||
|
||||
unsafe extern "C" {
|
||||
pub unsafe fn vmp_prepare_row_dft(
|
||||
module: *const MODULE,
|
||||
pmat: *mut VMP_PMAT,
|
||||
row: *const VEC_ZNX_DFT,
|
||||
row_i: u64,
|
||||
nrows: u64,
|
||||
ncols: u64,
|
||||
);
|
||||
}
|
||||
|
||||
unsafe extern "C" {
|
||||
pub unsafe fn vmp_extract_row_dft(
|
||||
module: *const MODULE,
|
||||
res: *mut VEC_ZNX_DFT,
|
||||
pmat: *const VMP_PMAT,
|
||||
row_i: u64,
|
||||
nrows: u64,
|
||||
ncols: u64,
|
||||
);
|
||||
}
|
||||
|
||||
unsafe extern "C" {
|
||||
pub unsafe fn vmp_extract_row(
|
||||
module: *const MODULE,
|
||||
res: *mut VEC_ZNX_BIG,
|
||||
pmat: *const VMP_PMAT,
|
||||
row_i: u64,
|
||||
nrows: u64,
|
||||
ncols: u64,
|
||||
);
|
||||
}
|
||||
|
||||
unsafe extern "C" {
|
||||
pub unsafe fn vmp_prepare_tmp_bytes(module: *const MODULE, nrows: u64, ncols: u64) -> u64;
|
||||
}
|
||||
|
||||
@@ -3,7 +3,7 @@ use crate::GALOISGENERATOR;
|
||||
|
||||
#[derive(Copy, Clone)]
|
||||
#[repr(u8)]
|
||||
pub enum MODULETYPE {
|
||||
pub enum BACKEND {
|
||||
FFT64,
|
||||
NTT120,
|
||||
}
|
||||
@@ -11,17 +11,17 @@ pub enum MODULETYPE {
|
||||
pub struct Module {
|
||||
pub ptr: *mut MODULE,
|
||||
pub n: usize,
|
||||
pub backend: MODULETYPE,
|
||||
pub backend: BACKEND,
|
||||
}
|
||||
|
||||
impl Module {
|
||||
// Instantiates a new module.
|
||||
pub fn new(n: usize, module_type: MODULETYPE) -> Self {
|
||||
pub fn new(n: usize, module_type: BACKEND) -> Self {
|
||||
unsafe {
|
||||
let module_type_u32: u32;
|
||||
match module_type {
|
||||
MODULETYPE::FFT64 => module_type_u32 = 0,
|
||||
MODULETYPE::NTT120 => module_type_u32 = 1,
|
||||
BACKEND::FFT64 => module_type_u32 = 0,
|
||||
BACKEND::NTT120 => module_type_u32 = 1,
|
||||
}
|
||||
let m: *mut module_info_t = new_module_info(n as u64, module_type_u32);
|
||||
if m.is_null() {
|
||||
@@ -35,7 +35,7 @@ impl Module {
|
||||
}
|
||||
}
|
||||
|
||||
pub fn backend(&self) -> MODULETYPE {
|
||||
pub fn backend(&self) -> BACKEND {
|
||||
self.backend
|
||||
}
|
||||
|
||||
|
||||
@@ -540,31 +540,6 @@ impl VecZnxOps for Module {
|
||||
/// # Panics
|
||||
///
|
||||
/// The method will panic if the argument `a` is greater than `a.cols()`.
|
||||
///
|
||||
/// # Example
|
||||
/// ```
|
||||
/// use base2k::{Module, MODULETYPE, VecZnx, Encoding, Infos, VecZnxOps};
|
||||
/// use itertools::izip;
|
||||
///
|
||||
/// let n: usize = 8; // polynomial degree
|
||||
/// let module = Module::new(n, MODULETYPE::FFT64);
|
||||
/// let mut a: VecZnx = module.new_vec_znx(2);
|
||||
/// let mut b: VecZnx = module.new_vec_znx(2);
|
||||
/// let mut c: VecZnx = module.new_vec_znx(2);
|
||||
///
|
||||
/// (0..a.cols()).for_each(|i|{
|
||||
/// a.at_mut(i).iter_mut().enumerate().for_each(|(i, x)|{
|
||||
/// *x = i as i64
|
||||
/// })
|
||||
/// });
|
||||
///
|
||||
/// module.vec_znx_automorphism(-1, &mut b, &a, 1); // X^i -> X^(-i)
|
||||
/// let col = c.at_mut(0);
|
||||
/// (1..col.len()).for_each(|i|{
|
||||
/// col[n-i] = -(i as i64)
|
||||
/// });
|
||||
/// izip!(b.raw().iter(), c.raw().iter()).for_each(|(a, b)| assert_eq!(a, b, "{} != {}", a, b));
|
||||
/// ```
|
||||
fn vec_znx_automorphism(&self, k: i64, b: &mut VecZnx, a: &VecZnx, a_cols: usize) {
|
||||
debug_assert_eq!(a.n(), self.n());
|
||||
debug_assert_eq!(b.n(), self.n());
|
||||
@@ -594,30 +569,6 @@ impl VecZnxOps for Module {
|
||||
/// # Panics
|
||||
///
|
||||
/// The method will panic if the argument `cols` is greater than `self.cols()`.
|
||||
///
|
||||
/// # Example
|
||||
/// ```
|
||||
/// use base2k::{Module, MODULETYPE, VecZnx, Encoding, Infos, VecZnxOps};
|
||||
/// use itertools::izip;
|
||||
///
|
||||
/// let n: usize = 8; // polynomial degree
|
||||
/// let module = Module::new(n, MODULETYPE::FFT64);
|
||||
/// let mut a: VecZnx = VecZnx::new(n, 2);
|
||||
/// let mut b: VecZnx = VecZnx::new(n, 2);
|
||||
///
|
||||
/// (0..a.cols()).for_each(|i|{
|
||||
/// a.at_mut(i).iter_mut().enumerate().for_each(|(i, x)|{
|
||||
/// *x = i as i64
|
||||
/// })
|
||||
/// });
|
||||
///
|
||||
/// module.vec_znx_automorphism_inplace(-1, &mut a, 1); // X^i -> X^(-i)
|
||||
/// let col = b.at_mut(0);
|
||||
/// (1..col.len()).for_each(|i|{
|
||||
/// col[n-i] = -(i as i64)
|
||||
/// });
|
||||
/// izip!(a.raw().iter(), b.raw().iter()).for_each(|(a, b)| assert_eq!(a, b, "{} != {}", a, b));
|
||||
/// ```
|
||||
fn vec_znx_automorphism_inplace(&self, k: i64, a: &mut VecZnx, a_cols: usize) {
|
||||
debug_assert_eq!(a.n(), self.n());
|
||||
debug_assert!(a.cols() >= a_cols);
|
||||
|
||||
@@ -1,12 +1,12 @@
|
||||
use crate::ffi::vec_znx_big::{self, vec_znx_bigcoeff_t};
|
||||
use crate::{alloc_aligned, assert_alignement, Infos, Module, VecZnx, VecZnxDft, MODULETYPE};
|
||||
use crate::ffi::vec_znx_big::{self, vec_znx_big_t};
|
||||
use crate::{alloc_aligned, assert_alignement, Infos, Module, VecZnx, VecZnxDft, BACKEND};
|
||||
|
||||
pub struct VecZnxBig {
|
||||
pub data: Vec<u8>,
|
||||
pub ptr: *mut u8,
|
||||
pub n: usize,
|
||||
pub cols: usize,
|
||||
pub backend: MODULETYPE,
|
||||
pub backend: BACKEND,
|
||||
}
|
||||
|
||||
impl VecZnxBig {
|
||||
@@ -62,9 +62,19 @@ impl VecZnxBig {
|
||||
self.cols
|
||||
}
|
||||
|
||||
pub fn backend(&self) -> MODULETYPE {
|
||||
pub fn backend(&self) -> BACKEND {
|
||||
self.backend
|
||||
}
|
||||
|
||||
/// Returns a non-mutable reference of `T` of the entire contiguous array of the [VecZnxDft].
|
||||
/// When using [`crate::FFT64`] as backend, `T` should be [f64].
|
||||
/// When using [`crate::NTT120`] as backend, `T` should be [i64].
|
||||
/// The length of the returned array is cols * n.
|
||||
pub fn raw<T>(&self, module: &Module) -> &[T] {
|
||||
let ptr: *const T = self.ptr as *const T;
|
||||
let len: usize = (self.cols() * module.n() * 8) / std::mem::size_of::<T>();
|
||||
unsafe { &std::slice::from_raw_parts(ptr, len) }
|
||||
}
|
||||
}
|
||||
|
||||
pub trait VecZnxBigOps {
|
||||
@@ -162,12 +172,12 @@ impl VecZnxBigOps for Module {
|
||||
unsafe {
|
||||
vec_znx_big::vec_znx_big_sub_small_a(
|
||||
self.ptr,
|
||||
b.ptr as *mut vec_znx_bigcoeff_t,
|
||||
b.ptr as *mut vec_znx_big_t,
|
||||
b.cols() as u64,
|
||||
a.as_ptr(),
|
||||
a.cols() as u64,
|
||||
a.n() as u64,
|
||||
b.ptr as *mut vec_znx_bigcoeff_t,
|
||||
b.ptr as *mut vec_znx_big_t,
|
||||
b.cols() as u64,
|
||||
)
|
||||
}
|
||||
@@ -177,12 +187,12 @@ impl VecZnxBigOps for Module {
|
||||
unsafe {
|
||||
vec_znx_big::vec_znx_big_sub_small_a(
|
||||
self.ptr,
|
||||
c.ptr as *mut vec_znx_bigcoeff_t,
|
||||
c.ptr as *mut vec_znx_big_t,
|
||||
c.cols() as u64,
|
||||
a.as_ptr(),
|
||||
a.cols() as u64,
|
||||
a.n() as u64,
|
||||
b.ptr as *mut vec_znx_bigcoeff_t,
|
||||
b.ptr as *mut vec_znx_big_t,
|
||||
b.cols() as u64,
|
||||
)
|
||||
}
|
||||
@@ -192,9 +202,9 @@ impl VecZnxBigOps for Module {
|
||||
unsafe {
|
||||
vec_znx_big::vec_znx_big_add_small(
|
||||
self.ptr,
|
||||
c.ptr as *mut vec_znx_bigcoeff_t,
|
||||
c.ptr as *mut vec_znx_big_t,
|
||||
c.cols() as u64,
|
||||
b.ptr as *mut vec_znx_bigcoeff_t,
|
||||
b.ptr as *mut vec_znx_big_t,
|
||||
b.cols() as u64,
|
||||
a.as_ptr(),
|
||||
a.cols() as u64,
|
||||
@@ -207,9 +217,9 @@ impl VecZnxBigOps for Module {
|
||||
unsafe {
|
||||
vec_znx_big::vec_znx_big_add_small(
|
||||
self.ptr,
|
||||
b.ptr as *mut vec_znx_bigcoeff_t,
|
||||
b.ptr as *mut vec_znx_big_t,
|
||||
b.cols() as u64,
|
||||
b.ptr as *mut vec_znx_bigcoeff_t,
|
||||
b.ptr as *mut vec_znx_big_t,
|
||||
b.cols() as u64,
|
||||
a.as_ptr(),
|
||||
a.cols() as u64,
|
||||
@@ -246,7 +256,7 @@ impl VecZnxBigOps for Module {
|
||||
b.as_mut_ptr(),
|
||||
b.cols() as u64,
|
||||
b.n() as u64,
|
||||
a.ptr as *mut vec_znx_bigcoeff_t,
|
||||
a.ptr as *mut vec_znx_big_t,
|
||||
a.cols() as u64,
|
||||
tmp_bytes.as_mut_ptr(),
|
||||
)
|
||||
@@ -284,7 +294,7 @@ impl VecZnxBigOps for Module {
|
||||
res.as_mut_ptr(),
|
||||
res.cols() as u64,
|
||||
res.n() as u64,
|
||||
a.ptr as *mut vec_znx_bigcoeff_t,
|
||||
a.ptr as *mut vec_znx_big_t,
|
||||
a_range_begin as u64,
|
||||
a_range_xend as u64,
|
||||
a_range_step as u64,
|
||||
@@ -298,9 +308,9 @@ impl VecZnxBigOps for Module {
|
||||
vec_znx_big::vec_znx_big_automorphism(
|
||||
self.ptr,
|
||||
gal_el,
|
||||
b.ptr as *mut vec_znx_bigcoeff_t,
|
||||
b.ptr as *mut vec_znx_big_t,
|
||||
b.cols() as u64,
|
||||
a.ptr as *mut vec_znx_bigcoeff_t,
|
||||
a.ptr as *mut vec_znx_big_t,
|
||||
a.cols() as u64,
|
||||
);
|
||||
}
|
||||
@@ -311,9 +321,9 @@ impl VecZnxBigOps for Module {
|
||||
vec_znx_big::vec_znx_big_automorphism(
|
||||
self.ptr,
|
||||
gal_el,
|
||||
a.ptr as *mut vec_znx_bigcoeff_t,
|
||||
a.ptr as *mut vec_znx_big_t,
|
||||
a.cols() as u64,
|
||||
a.ptr as *mut vec_znx_bigcoeff_t,
|
||||
a.ptr as *mut vec_znx_big_t,
|
||||
a.cols() as u64,
|
||||
);
|
||||
}
|
||||
|
||||
@@ -1,15 +1,15 @@
|
||||
use crate::ffi::vec_znx_big::vec_znx_bigcoeff_t;
|
||||
use crate::ffi::vec_znx_big::vec_znx_big_t;
|
||||
use crate::ffi::vec_znx_dft;
|
||||
use crate::ffi::vec_znx_dft::{bytes_of_vec_znx_dft, vec_znx_dft_t};
|
||||
use crate::{alloc_aligned, VecZnx};
|
||||
use crate::{assert_alignement, Infos, Module, VecZnxBig, MODULETYPE};
|
||||
use crate::{assert_alignement, Infos, Module, VecZnxBig, BACKEND};
|
||||
|
||||
pub struct VecZnxDft {
|
||||
pub data: Vec<u8>,
|
||||
pub ptr: *mut u8,
|
||||
pub n: usize,
|
||||
pub cols: usize,
|
||||
pub backend: MODULETYPE,
|
||||
pub backend: BACKEND,
|
||||
}
|
||||
|
||||
impl VecZnxDft {
|
||||
@@ -69,7 +69,7 @@ impl VecZnxDft {
|
||||
self.cols
|
||||
}
|
||||
|
||||
pub fn backend(&self) -> MODULETYPE {
|
||||
pub fn backend(&self) -> BACKEND {
|
||||
self.backend
|
||||
}
|
||||
|
||||
@@ -133,17 +133,17 @@ pub trait VecZnxDftOps {
|
||||
fn vec_znx_idft_tmp_bytes(&self) -> usize;
|
||||
|
||||
/// b <- IDFT(a), uses a as scratch space.
|
||||
fn vec_znx_idft_tmp_a(&self, b: &mut VecZnxBig, a: &mut VecZnxDft, a_limbs: usize);
|
||||
fn vec_znx_idft_tmp_a(&self, b: &mut VecZnxBig, a: &mut VecZnxDft, a_cols: usize);
|
||||
|
||||
fn vec_znx_idft(
|
||||
&self,
|
||||
b: &mut VecZnxBig,
|
||||
a: &mut VecZnxDft,
|
||||
a_limbs: usize,
|
||||
a: &VecZnxDft,
|
||||
a_cols: usize,
|
||||
tmp_bytes: &mut [u8],
|
||||
);
|
||||
|
||||
fn vec_znx_dft(&self, b: &mut VecZnxDft, a: &VecZnx, a_limbs: usize);
|
||||
fn vec_znx_dft(&self, b: &mut VecZnxDft, a: &VecZnx, a_cols: usize);
|
||||
}
|
||||
|
||||
impl VecZnxDftOps for Module {
|
||||
@@ -177,20 +177,20 @@ impl VecZnxDftOps for Module {
|
||||
unsafe { bytes_of_vec_znx_dft(self.ptr, cols as u64) as usize }
|
||||
}
|
||||
|
||||
fn vec_znx_idft_tmp_a(&self, b: &mut VecZnxBig, a: &mut VecZnxDft, a_limbs: usize) {
|
||||
fn vec_znx_idft_tmp_a(&self, b: &mut VecZnxBig, a: &mut VecZnxDft, a_cols: usize) {
|
||||
debug_assert!(
|
||||
b.cols() >= a_limbs,
|
||||
"invalid c_vector: b_vector.cols()={} < a_limbs={}",
|
||||
b.cols() >= a_cols,
|
||||
"invalid c_vector: b_vector.cols()={} < a_cols={}",
|
||||
b.cols(),
|
||||
a_limbs
|
||||
a_cols
|
||||
);
|
||||
unsafe {
|
||||
vec_znx_dft::vec_znx_idft_tmp_a(
|
||||
self.ptr,
|
||||
b.ptr as *mut vec_znx_bigcoeff_t,
|
||||
b.ptr as *mut vec_znx_big_t,
|
||||
b.cols() as u64,
|
||||
a.ptr as *mut vec_znx_dft_t,
|
||||
a_limbs as u64,
|
||||
a_cols as u64,
|
||||
)
|
||||
}
|
||||
}
|
||||
@@ -226,7 +226,7 @@ impl VecZnxDftOps for Module {
|
||||
fn vec_znx_idft(
|
||||
&self,
|
||||
b: &mut VecZnxBig,
|
||||
a: &mut VecZnxDft,
|
||||
a: &VecZnxDft,
|
||||
a_cols: usize,
|
||||
tmp_bytes: &mut [u8],
|
||||
) {
|
||||
@@ -243,7 +243,7 @@ impl VecZnxDftOps for Module {
|
||||
a_cols
|
||||
);
|
||||
debug_assert!(
|
||||
tmp_bytes.len() <= <Module as VecZnxDftOps>::vec_znx_idft_tmp_bytes(self),
|
||||
tmp_bytes.len() >= <Module as VecZnxDftOps>::vec_znx_idft_tmp_bytes(self),
|
||||
"invalid tmp_bytes: tmp_bytes.len()={} < self.vec_znx_idft_tmp_bytes()={}",
|
||||
tmp_bytes.len(),
|
||||
<Module as VecZnxDftOps>::vec_znx_idft_tmp_bytes(self)
|
||||
@@ -255,9 +255,9 @@ impl VecZnxDftOps for Module {
|
||||
unsafe {
|
||||
vec_znx_dft::vec_znx_idft(
|
||||
self.ptr,
|
||||
b.ptr as *mut vec_znx_bigcoeff_t,
|
||||
b.ptr as *mut vec_znx_big_t,
|
||||
a.cols() as u64,
|
||||
a.ptr as *mut vec_znx_dft_t,
|
||||
a.ptr as *const vec_znx_dft_t,
|
||||
a_cols as u64,
|
||||
tmp_bytes.as_mut_ptr(),
|
||||
)
|
||||
|
||||
@@ -1,6 +1,9 @@
|
||||
use crate::ffi::vec_znx_big::vec_znx_big_t;
|
||||
use crate::ffi::vec_znx_dft::vec_znx_dft_t;
|
||||
use crate::ffi::vmp::{self, vmp_pmat_t};
|
||||
use crate::{alloc_aligned, assert_alignement, Infos, Module, VecZnx, VecZnxDft, MODULETYPE};
|
||||
use crate::{
|
||||
alloc_aligned, assert_alignement, Infos, Module, VecZnx, VecZnxBig, VecZnxDft, BACKEND,
|
||||
};
|
||||
|
||||
/// Vector Matrix Product Prepared Matrix: a vector of [VecZnx],
|
||||
/// stored as a 3D matrix in the DFT domain in a single contiguous array.
|
||||
@@ -23,7 +26,7 @@ pub struct VmpPMat {
|
||||
/// The ring degree of each [VecZnxDft].
|
||||
n: usize,
|
||||
|
||||
backend: MODULETYPE,
|
||||
backend: BACKEND,
|
||||
}
|
||||
|
||||
impl Infos for VmpPMat {
|
||||
@@ -59,7 +62,7 @@ impl VmpPMat {
|
||||
self.ptr
|
||||
}
|
||||
|
||||
pub fn borrowed(&self) -> bool{
|
||||
pub fn borrowed(&self) -> bool {
|
||||
self.data.len() == 0
|
||||
}
|
||||
|
||||
@@ -167,7 +170,7 @@ pub trait VmpPMatOps {
|
||||
/// The size of buf can be obtained with [VmpPMatOps::vmp_prepare_tmp_bytes].
|
||||
fn vmp_prepare_dblptr(&self, b: &mut VmpPMat, a: &[&[i64]], buf: &mut [u8]);
|
||||
|
||||
/// Prepares the ith-row of [VmpPMat] from a vector of [VecZnx].
|
||||
/// Prepares the ith-row of [VmpPMat] from a [VecZnx].
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
@@ -179,6 +182,35 @@ pub trait VmpPMatOps {
|
||||
/// The size of buf can be obtained with [VmpPMatOps::vmp_prepare_tmp_bytes].
|
||||
fn vmp_prepare_row(&self, b: &mut VmpPMat, a: &[i64], row_i: usize, tmp_bytes: &mut [u8]);
|
||||
|
||||
/// Extracts the ith-row of [VmpPMat] into a [VecZnxBig].
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `b`: the [VecZnxBig] to on which to extract the row of the [VmpPMat].
|
||||
/// * `a`: [VmpPMat] on which the values are encoded.
|
||||
/// * `row_i`: the index of the row to extract.
|
||||
fn vmp_extract_row(&self, b: &mut VecZnxBig, a: &VmpPMat, row_i: usize);
|
||||
|
||||
/// Prepares the ith-row of [VmpPMat] from a [VecZnxDft].
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `b`: [VmpPMat] on which the values are encoded.
|
||||
/// * `a`: the [VecZnxDft] to encode on the [VmpPMat].
|
||||
/// * `row_i`: the index of the row to prepare.
|
||||
///
|
||||
/// The size of buf can be obtained with [VmpPMatOps::vmp_prepare_tmp_bytes].
|
||||
fn vmp_prepare_row_dft(&self, b: &mut VmpPMat, a: &VecZnxDft, row_i: usize);
|
||||
|
||||
/// Extracts the ith-row of [VmpPMat] into a [VecZnxDft].
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `b`: the [VecZnxDft] to on which to extract the row of the [VmpPMat].
|
||||
/// * `a`: [VmpPMat] on which the values are encoded.
|
||||
/// * `row_i`: the index of the row to extract.
|
||||
fn vmp_extract_row_dft(&self, b: &mut VecZnxDft, a: &VmpPMat, row_i: usize);
|
||||
|
||||
/// Returns the size of the stratch space necessary for [VmpPMatOps::vmp_apply_dft].
|
||||
///
|
||||
/// # Arguments
|
||||
@@ -375,6 +407,60 @@ impl VmpPMatOps for Module {
|
||||
}
|
||||
}
|
||||
|
||||
fn vmp_extract_row(&self, b: &mut VecZnxBig, a: &VmpPMat, row_i: usize) {
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
assert_eq!(a.n(), b.n());
|
||||
assert_eq!(a.cols(), b.cols());
|
||||
}
|
||||
unsafe {
|
||||
vmp::vmp_extract_row(
|
||||
self.ptr,
|
||||
b.ptr as *mut vec_znx_big_t,
|
||||
a.as_ptr() as *const vmp_pmat_t,
|
||||
row_i as u64,
|
||||
a.rows() as u64,
|
||||
a.cols() as u64,
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
fn vmp_prepare_row_dft(&self, b: &mut VmpPMat, a: &VecZnxDft, row_i: usize) {
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
assert_eq!(a.n(), b.n());
|
||||
assert_eq!(a.cols(), b.cols());
|
||||
}
|
||||
unsafe {
|
||||
vmp::vmp_prepare_row_dft(
|
||||
self.ptr,
|
||||
b.as_mut_ptr() as *mut vmp_pmat_t,
|
||||
a.ptr as *const vec_znx_dft_t,
|
||||
row_i as u64,
|
||||
b.rows() as u64,
|
||||
b.cols() as u64,
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
fn vmp_extract_row_dft(&self, b: &mut VecZnxDft, a: &VmpPMat, row_i: usize) {
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
assert_eq!(a.n(), b.n());
|
||||
assert_eq!(a.cols(), b.cols());
|
||||
}
|
||||
unsafe {
|
||||
vmp::vmp_extract_row_dft(
|
||||
self.ptr,
|
||||
b.ptr as *mut vec_znx_dft_t,
|
||||
a.as_ptr() as *const vmp_pmat_t,
|
||||
row_i as u64,
|
||||
a.rows() as u64,
|
||||
a.cols() as u64,
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
fn vmp_apply_dft_tmp_bytes(
|
||||
&self,
|
||||
res_cols: usize,
|
||||
@@ -489,3 +575,52 @@ impl VmpPMatOps for Module {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use crate::{
|
||||
alloc_aligned, Module, Sampling, VecZnx, VecZnxBig, VecZnxBigOps, VecZnxDft, VecZnxDftOps,
|
||||
VecZnxOps, VmpPMat, VmpPMatOps,
|
||||
};
|
||||
use sampling::source::Source;
|
||||
|
||||
#[test]
|
||||
fn vmp_prepare_row_dft() {
|
||||
let module: Module = Module::new(32, crate::BACKEND::FFT64);
|
||||
let vpmat_rows: usize = 4;
|
||||
let vpmat_cols: usize = 5;
|
||||
let log_base2k: usize = 8;
|
||||
let mut a: VecZnx = module.new_vec_znx(vpmat_cols);
|
||||
let mut a_dft: VecZnxDft = module.new_vec_znx_dft(vpmat_cols);
|
||||
let mut a_big: VecZnxBig = module.new_vec_znx_big(vpmat_cols);
|
||||
let mut b_big: VecZnxBig = module.new_vec_znx_big(vpmat_cols);
|
||||
let mut b_dft: VecZnxDft = module.new_vec_znx_dft(vpmat_cols);
|
||||
let mut vmpmat_0: VmpPMat = module.new_vmp_pmat(vpmat_rows, vpmat_cols);
|
||||
let mut vmpmat_1: VmpPMat = module.new_vmp_pmat(vpmat_rows, vpmat_cols);
|
||||
|
||||
let mut tmp_bytes: Vec<u8> =
|
||||
alloc_aligned(module.vmp_prepare_tmp_bytes(vpmat_rows, vpmat_cols));
|
||||
|
||||
for row_i in 0..vpmat_rows {
|
||||
let mut source: Source = Source::new([0u8; 32]);
|
||||
module.fill_uniform(log_base2k, &mut a, vpmat_cols, &mut source);
|
||||
module.vec_znx_dft(&mut a_dft, &a, vpmat_cols);
|
||||
module.vmp_prepare_row(&mut vmpmat_0, &a.raw(), row_i, &mut tmp_bytes);
|
||||
|
||||
// Checks that prepare(vmp_pmat, a) = prepare_dft(vmp_pmat, a_dft)
|
||||
module.vmp_prepare_row_dft(&mut vmpmat_1, &a_dft, row_i);
|
||||
assert_eq!(vmpmat_0.raw::<u8>(), vmpmat_1.raw::<u8>());
|
||||
|
||||
// Checks that a_dft = extract_dft(prepare(vmp_pmat, a), b_dft)
|
||||
module.vmp_extract_row_dft(&mut b_dft, &vmpmat_0, row_i);
|
||||
assert_eq!(a_dft.raw::<u8>(&module), b_dft.raw::<u8>(&module));
|
||||
|
||||
// Checks that a_big = extract(prepare_dft(vmp_pmat, a_dft), b_big)
|
||||
module.vmp_extract_row(&mut b_big, &vmpmat_0, row_i);
|
||||
module.vec_znx_idft(&mut a_big, &a_dft, vpmat_cols, &mut tmp_bytes);
|
||||
assert_eq!(a_big.raw::<i64>(&module), b_big.raw::<i64>(&module));
|
||||
}
|
||||
|
||||
module.free();
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user