added Added vmp_extract_row, vmp_extract_row_dft, vmp_extract_tmp_bytes, vmp_prepare_row_dft

-
This commit is contained in:
Jean-Philippe Bossuat
2025-04-16 11:31:58 +02:00
parent 4c1dbc70e5
commit 89369dcdf9
18 changed files with 293 additions and 181 deletions

11
.vscode/settings.json vendored
View File

@@ -57,6 +57,15 @@
"xloctime": "cpp",
"xmemory": "cpp",
"xtr1common": "cpp",
"vec_znx_arithmetic_private.h": "c"
"vec_znx_arithmetic_private.h": "c",
"reim4_arithmetic.h": "c",
"array": "c",
"string_view": "c"
},
"github.copilot.enable": {
"*": false,
"plaintext": false,
"markdown": false,
"scminput": false
}
}

View File

@@ -1,6 +1,6 @@
use base2k::{
alloc_aligned, Encoding, Infos, Module, Sampling, Scalar, SvpPPol, SvpPPolOps, VecZnx,
VecZnxBig, VecZnxBigOps, VecZnxDft, VecZnxDftOps, VecZnxOps, MODULETYPE,
VecZnxBig, VecZnxBigOps, VecZnxDft, VecZnxDftOps, VecZnxOps, BACKEND,
};
use itertools::izip;
use sampling::source::Source;
@@ -11,7 +11,7 @@ fn main() {
let cols: usize = 3;
let msg_cols: usize = 2;
let log_scale: usize = msg_cols * log_base2k - 5;
let module: Module = Module::new(n, MODULETYPE::FFT64);
let module: Module = Module::new(n, BACKEND::FFT64);
let mut carry: Vec<u8> = alloc_aligned(module.vec_znx_big_normalize_tmp_bytes());

View File

@@ -1,13 +1,13 @@
use base2k::{
alloc_aligned, Encoding, Infos, Module, VecZnx, VecZnxBig, VecZnxBigOps, VecZnxDft,
VecZnxDftOps, VecZnxOps, VecZnxVec, VmpPMat, VmpPMatOps, MODULETYPE,
VecZnxDftOps, VecZnxOps, VecZnxVec, VmpPMat, VmpPMatOps, BACKEND,
};
fn main() {
let log_n: i32 = 5;
let n: usize = 1 << log_n;
let module: Module = Module::new(n, MODULETYPE::FFT64);
let module: Module = Module::new(n, BACKEND::FFT64);
let log_base2k: usize = 15;
let cols: usize = 5;
let log_k: usize = log_base2k * cols - 5;

View File

@@ -1,36 +0,0 @@
/*
[build-dependencies]
bindgen ="0.71.1"
//use bindgen;
//use std::env;
//use std::fs;
//use std::path::PathBuf;
//use std::time::SystemTime;
// Path to the C header file
let header_paths: [&str; 2] = [
"spqlios-arithmetic/spqlios/coeffs/coeffs_arithmetic.h",
"spqlios-arithmetic/spqlios/arithmetic/vec_znx_arithmetic.h",
];
let out_path: PathBuf = PathBuf::from(env::var("OUT_DIR").unwrap());
let bindings_file = out_path.join("bindings.rs");
let mut builder: bindgen::Builder = bindgen::Builder::default();
for header in header_paths {
builder = builder.header(header);
}
let bindings = builder
.generate_comments(false) // Optional: includes comments in bindings
.generate_inline_functions(true) // Optional: includes inline functions
.generate()
.expect("Unable to generate bindings");
// Write the bindings to the OUT_DIR
bindings
.write_to_file(&bindings_file)
.expect("Couldn't write bindings!");
*/

View File

@@ -59,40 +59,40 @@ pub struct reim_to_znx64_precomp {
}
pub type REIM_TO_ZNX64_PRECOMP = reim_to_znx64_precomp;
unsafe extern "C" {
pub fn new_reim_fft_precomp(m: u32, num_buffers: u32) -> *mut REIM_FFT_PRECOMP;
pub unsafe fn new_reim_fft_precomp(m: u32, num_buffers: u32) -> *mut REIM_FFT_PRECOMP;
}
unsafe extern "C" {
pub fn reim_fft_precomp_get_buffer(
pub unsafe fn reim_fft_precomp_get_buffer(
tables: *const REIM_FFT_PRECOMP,
buffer_index: u32,
) -> *mut f64;
}
unsafe extern "C" {
pub fn new_reim_fft_buffer(m: u32) -> *mut f64;
pub unsafe fn new_reim_fft_buffer(m: u32) -> *mut f64;
}
unsafe extern "C" {
pub fn delete_reim_fft_buffer(buffer: *mut f64);
pub unsafe fn delete_reim_fft_buffer(buffer: *mut f64);
}
unsafe extern "C" {
pub fn reim_fft(tables: *const REIM_FFT_PRECOMP, data: *mut f64);
pub unsafe fn reim_fft(tables: *const REIM_FFT_PRECOMP, data: *mut f64);
}
unsafe extern "C" {
pub fn new_reim_ifft_precomp(m: u32, num_buffers: u32) -> *mut REIM_IFFT_PRECOMP;
pub unsafe fn new_reim_ifft_precomp(m: u32, num_buffers: u32) -> *mut REIM_IFFT_PRECOMP;
}
unsafe extern "C" {
pub fn reim_ifft_precomp_get_buffer(
pub unsafe fn reim_ifft_precomp_get_buffer(
tables: *const REIM_IFFT_PRECOMP,
buffer_index: u32,
) -> *mut f64;
}
unsafe extern "C" {
pub fn reim_ifft(tables: *const REIM_IFFT_PRECOMP, data: *mut f64);
pub unsafe fn reim_ifft(tables: *const REIM_IFFT_PRECOMP, data: *mut f64);
}
unsafe extern "C" {
pub fn new_reim_fftvec_mul_precomp(m: u32) -> *mut REIM_FFTVEC_MUL_PRECOMP;
pub unsafe fn new_reim_fftvec_mul_precomp(m: u32) -> *mut REIM_FFTVEC_MUL_PRECOMP;
}
unsafe extern "C" {
pub fn reim_fftvec_mul(
pub unsafe fn reim_fftvec_mul(
tables: *const REIM_FFTVEC_MUL_PRECOMP,
r: *mut f64,
a: *const f64,
@@ -100,10 +100,10 @@ unsafe extern "C" {
);
}
unsafe extern "C" {
pub fn new_reim_fftvec_addmul_precomp(m: u32) -> *mut REIM_FFTVEC_ADDMUL_PRECOMP;
pub unsafe fn new_reim_fftvec_addmul_precomp(m: u32) -> *mut REIM_FFTVEC_ADDMUL_PRECOMP;
}
unsafe extern "C" {
pub fn reim_fftvec_addmul(
pub unsafe fn reim_fftvec_addmul(
tables: *const REIM_FFTVEC_ADDMUL_PRECOMP,
r: *mut f64,
a: *const f64,
@@ -111,27 +111,30 @@ unsafe extern "C" {
);
}
unsafe extern "C" {
pub fn new_reim_from_znx32_precomp(m: u32, log2bound: u32) -> *mut REIM_FROM_ZNX32_PRECOMP;
pub unsafe fn new_reim_from_znx32_precomp(
m: u32,
log2bound: u32,
) -> *mut REIM_FROM_ZNX32_PRECOMP;
}
unsafe extern "C" {
pub fn reim_from_znx32(
pub unsafe fn reim_from_znx32(
tables: *const REIM_FROM_ZNX32_PRECOMP,
r: *mut ::std::os::raw::c_void,
a: *const i32,
);
}
unsafe extern "C" {
pub fn reim_from_znx64(
pub unsafe fn reim_from_znx64(
tables: *const REIM_FROM_ZNX64_PRECOMP,
r: *mut ::std::os::raw::c_void,
a: *const i64,
);
}
unsafe extern "C" {
pub fn new_reim_from_znx64_precomp(m: u32, maxbnd: u32) -> *mut REIM_FROM_ZNX64_PRECOMP;
pub unsafe fn new_reim_from_znx64_precomp(m: u32, maxbnd: u32) -> *mut REIM_FROM_ZNX64_PRECOMP;
}
unsafe extern "C" {
pub fn reim_from_znx64_simple(
pub unsafe fn reim_from_znx64_simple(
m: u32,
log2bound: u32,
r: *mut ::std::os::raw::c_void,
@@ -139,58 +142,64 @@ unsafe extern "C" {
);
}
unsafe extern "C" {
pub fn new_reim_from_tnx32_precomp(m: u32) -> *mut REIM_FROM_TNX32_PRECOMP;
pub unsafe fn new_reim_from_tnx32_precomp(m: u32) -> *mut REIM_FROM_TNX32_PRECOMP;
}
unsafe extern "C" {
pub fn reim_from_tnx32(
pub unsafe fn reim_from_tnx32(
tables: *const REIM_FROM_TNX32_PRECOMP,
r: *mut ::std::os::raw::c_void,
a: *const i32,
);
}
unsafe extern "C" {
pub fn new_reim_to_tnx32_precomp(
pub unsafe fn new_reim_to_tnx32_precomp(
m: u32,
divisor: f64,
log2overhead: u32,
) -> *mut REIM_TO_TNX32_PRECOMP;
}
unsafe extern "C" {
pub fn reim_to_tnx32(
pub unsafe fn reim_to_tnx32(
tables: *const REIM_TO_TNX32_PRECOMP,
r: *mut i32,
a: *const ::std::os::raw::c_void,
);
}
unsafe extern "C" {
pub fn new_reim_to_tnx_precomp(
pub unsafe fn new_reim_to_tnx_precomp(
m: u32,
divisor: f64,
log2overhead: u32,
) -> *mut REIM_TO_TNX_PRECOMP;
}
unsafe extern "C" {
pub fn reim_to_tnx(tables: *const REIM_TO_TNX_PRECOMP, r: *mut f64, a: *const f64);
pub unsafe fn reim_to_tnx(tables: *const REIM_TO_TNX_PRECOMP, r: *mut f64, a: *const f64);
}
unsafe extern "C" {
pub fn reim_to_tnx_simple(m: u32, divisor: f64, log2overhead: u32, r: *mut f64, a: *const f64);
pub unsafe fn reim_to_tnx_simple(
m: u32,
divisor: f64,
log2overhead: u32,
r: *mut f64,
a: *const f64,
);
}
unsafe extern "C" {
pub fn new_reim_to_znx64_precomp(
pub unsafe fn new_reim_to_znx64_precomp(
m: u32,
divisor: f64,
log2bound: u32,
) -> *mut REIM_TO_ZNX64_PRECOMP;
}
unsafe extern "C" {
pub fn reim_to_znx64(
pub unsafe fn reim_to_znx64(
precomp: *const REIM_TO_ZNX64_PRECOMP,
r: *mut i64,
a: *const ::std::os::raw::c_void,
);
}
unsafe extern "C" {
pub fn reim_to_znx64_simple(
pub unsafe fn reim_to_znx64_simple(
m: u32,
divisor: f64,
log2bound: u32,
@@ -199,13 +208,13 @@ unsafe extern "C" {
);
}
unsafe extern "C" {
pub fn reim_fft_simple(m: u32, data: *mut ::std::os::raw::c_void);
pub unsafe fn reim_fft_simple(m: u32, data: *mut ::std::os::raw::c_void);
}
unsafe extern "C" {
pub fn reim_ifft_simple(m: u32, data: *mut ::std::os::raw::c_void);
pub unsafe fn reim_ifft_simple(m: u32, data: *mut ::std::os::raw::c_void);
}
unsafe extern "C" {
pub fn reim_fftvec_mul_simple(
pub unsafe fn reim_fftvec_mul_simple(
m: u32,
r: *mut ::std::os::raw::c_void,
a: *const ::std::os::raw::c_void,
@@ -213,7 +222,7 @@ unsafe extern "C" {
);
}
unsafe extern "C" {
pub fn reim_fftvec_addmul_simple(
pub unsafe fn reim_fftvec_addmul_simple(
m: u32,
r: *mut ::std::os::raw::c_void,
a: *const ::std::os::raw::c_void,
@@ -221,7 +230,7 @@ unsafe extern "C" {
);
}
unsafe extern "C" {
pub fn reim_from_znx32_simple(
pub unsafe fn reim_from_znx32_simple(
m: u32,
log2bound: u32,
r: *mut ::std::os::raw::c_void,
@@ -229,10 +238,10 @@ unsafe extern "C" {
);
}
unsafe extern "C" {
pub fn reim_from_tnx32_simple(m: u32, r: *mut ::std::os::raw::c_void, x: *const i32);
pub unsafe fn reim_from_tnx32_simple(m: u32, r: *mut ::std::os::raw::c_void, x: *const i32);
}
unsafe extern "C" {
pub fn reim_to_tnx32_simple(
pub unsafe fn reim_to_tnx32_simple(
m: u32,
divisor: f64,
log2overhead: u32,

View File

@@ -2,10 +2,10 @@ use crate::ffi::module::MODULE;
#[repr(C)]
#[derive(Debug, Copy, Clone)]
pub struct vec_znx_bigcoeff_t {
pub struct vec_znx_big_t {
_unused: [u8; 0],
}
pub type VEC_ZNX_BIG = vec_znx_bigcoeff_t;
pub type VEC_ZNX_BIG = vec_znx_big_t;
unsafe extern "C" {
pub fn bytes_of_vec_znx_big(module: *const MODULE, size: u64) -> u64;

View File

@@ -1,4 +1,5 @@
use crate::ffi::module::MODULE;
use crate::ffi::vec_znx_big::VEC_ZNX_BIG;
use crate::ffi::vec_znx_dft::VEC_ZNX_DFT;
#[repr(C)]
@@ -103,6 +104,39 @@ unsafe extern "C" {
);
}
unsafe extern "C" {
pub unsafe fn vmp_prepare_row_dft(
module: *const MODULE,
pmat: *mut VMP_PMAT,
row: *const VEC_ZNX_DFT,
row_i: u64,
nrows: u64,
ncols: u64,
);
}
unsafe extern "C" {
pub unsafe fn vmp_extract_row_dft(
module: *const MODULE,
res: *mut VEC_ZNX_DFT,
pmat: *const VMP_PMAT,
row_i: u64,
nrows: u64,
ncols: u64,
);
}
unsafe extern "C" {
pub unsafe fn vmp_extract_row(
module: *const MODULE,
res: *mut VEC_ZNX_BIG,
pmat: *const VMP_PMAT,
row_i: u64,
nrows: u64,
ncols: u64,
);
}
unsafe extern "C" {
pub unsafe fn vmp_prepare_tmp_bytes(module: *const MODULE, nrows: u64, ncols: u64) -> u64;
}

View File

@@ -3,7 +3,7 @@ use crate::GALOISGENERATOR;
#[derive(Copy, Clone)]
#[repr(u8)]
pub enum MODULETYPE {
pub enum BACKEND {
FFT64,
NTT120,
}
@@ -11,17 +11,17 @@ pub enum MODULETYPE {
pub struct Module {
pub ptr: *mut MODULE,
pub n: usize,
pub backend: MODULETYPE,
pub backend: BACKEND,
}
impl Module {
// Instantiates a new module.
pub fn new(n: usize, module_type: MODULETYPE) -> Self {
pub fn new(n: usize, module_type: BACKEND) -> Self {
unsafe {
let module_type_u32: u32;
match module_type {
MODULETYPE::FFT64 => module_type_u32 = 0,
MODULETYPE::NTT120 => module_type_u32 = 1,
BACKEND::FFT64 => module_type_u32 = 0,
BACKEND::NTT120 => module_type_u32 = 1,
}
let m: *mut module_info_t = new_module_info(n as u64, module_type_u32);
if m.is_null() {
@@ -35,7 +35,7 @@ impl Module {
}
}
pub fn backend(&self) -> MODULETYPE {
pub fn backend(&self) -> BACKEND {
self.backend
}

View File

@@ -540,31 +540,6 @@ impl VecZnxOps for Module {
/// # Panics
///
/// The method will panic if the argument `a` is greater than `a.cols()`.
///
/// # Example
/// ```
/// use base2k::{Module, MODULETYPE, VecZnx, Encoding, Infos, VecZnxOps};
/// use itertools::izip;
///
/// let n: usize = 8; // polynomial degree
/// let module = Module::new(n, MODULETYPE::FFT64);
/// let mut a: VecZnx = module.new_vec_znx(2);
/// let mut b: VecZnx = module.new_vec_znx(2);
/// let mut c: VecZnx = module.new_vec_znx(2);
///
/// (0..a.cols()).for_each(|i|{
/// a.at_mut(i).iter_mut().enumerate().for_each(|(i, x)|{
/// *x = i as i64
/// })
/// });
///
/// module.vec_znx_automorphism(-1, &mut b, &a, 1); // X^i -> X^(-i)
/// let col = c.at_mut(0);
/// (1..col.len()).for_each(|i|{
/// col[n-i] = -(i as i64)
/// });
/// izip!(b.raw().iter(), c.raw().iter()).for_each(|(a, b)| assert_eq!(a, b, "{} != {}", a, b));
/// ```
fn vec_znx_automorphism(&self, k: i64, b: &mut VecZnx, a: &VecZnx, a_cols: usize) {
debug_assert_eq!(a.n(), self.n());
debug_assert_eq!(b.n(), self.n());
@@ -594,30 +569,6 @@ impl VecZnxOps for Module {
/// # Panics
///
/// The method will panic if the argument `cols` is greater than `self.cols()`.
///
/// # Example
/// ```
/// use base2k::{Module, MODULETYPE, VecZnx, Encoding, Infos, VecZnxOps};
/// use itertools::izip;
///
/// let n: usize = 8; // polynomial degree
/// let module = Module::new(n, MODULETYPE::FFT64);
/// let mut a: VecZnx = VecZnx::new(n, 2);
/// let mut b: VecZnx = VecZnx::new(n, 2);
///
/// (0..a.cols()).for_each(|i|{
/// a.at_mut(i).iter_mut().enumerate().for_each(|(i, x)|{
/// *x = i as i64
/// })
/// });
///
/// module.vec_znx_automorphism_inplace(-1, &mut a, 1); // X^i -> X^(-i)
/// let col = b.at_mut(0);
/// (1..col.len()).for_each(|i|{
/// col[n-i] = -(i as i64)
/// });
/// izip!(a.raw().iter(), b.raw().iter()).for_each(|(a, b)| assert_eq!(a, b, "{} != {}", a, b));
/// ```
fn vec_znx_automorphism_inplace(&self, k: i64, a: &mut VecZnx, a_cols: usize) {
debug_assert_eq!(a.n(), self.n());
debug_assert!(a.cols() >= a_cols);

View File

@@ -1,12 +1,12 @@
use crate::ffi::vec_znx_big::{self, vec_znx_bigcoeff_t};
use crate::{alloc_aligned, assert_alignement, Infos, Module, VecZnx, VecZnxDft, MODULETYPE};
use crate::ffi::vec_znx_big::{self, vec_znx_big_t};
use crate::{alloc_aligned, assert_alignement, Infos, Module, VecZnx, VecZnxDft, BACKEND};
pub struct VecZnxBig {
pub data: Vec<u8>,
pub ptr: *mut u8,
pub n: usize,
pub cols: usize,
pub backend: MODULETYPE,
pub backend: BACKEND,
}
impl VecZnxBig {
@@ -62,9 +62,19 @@ impl VecZnxBig {
self.cols
}
pub fn backend(&self) -> MODULETYPE {
pub fn backend(&self) -> BACKEND {
self.backend
}
/// Returns a non-mutable reference of `T` of the entire contiguous array of the [VecZnxDft].
/// When using [`crate::FFT64`] as backend, `T` should be [f64].
/// When using [`crate::NTT120`] as backend, `T` should be [i64].
/// The length of the returned array is cols * n.
pub fn raw<T>(&self, module: &Module) -> &[T] {
let ptr: *const T = self.ptr as *const T;
let len: usize = (self.cols() * module.n() * 8) / std::mem::size_of::<T>();
unsafe { &std::slice::from_raw_parts(ptr, len) }
}
}
pub trait VecZnxBigOps {
@@ -162,12 +172,12 @@ impl VecZnxBigOps for Module {
unsafe {
vec_znx_big::vec_znx_big_sub_small_a(
self.ptr,
b.ptr as *mut vec_znx_bigcoeff_t,
b.ptr as *mut vec_znx_big_t,
b.cols() as u64,
a.as_ptr(),
a.cols() as u64,
a.n() as u64,
b.ptr as *mut vec_znx_bigcoeff_t,
b.ptr as *mut vec_znx_big_t,
b.cols() as u64,
)
}
@@ -177,12 +187,12 @@ impl VecZnxBigOps for Module {
unsafe {
vec_znx_big::vec_znx_big_sub_small_a(
self.ptr,
c.ptr as *mut vec_znx_bigcoeff_t,
c.ptr as *mut vec_znx_big_t,
c.cols() as u64,
a.as_ptr(),
a.cols() as u64,
a.n() as u64,
b.ptr as *mut vec_znx_bigcoeff_t,
b.ptr as *mut vec_znx_big_t,
b.cols() as u64,
)
}
@@ -192,9 +202,9 @@ impl VecZnxBigOps for Module {
unsafe {
vec_znx_big::vec_znx_big_add_small(
self.ptr,
c.ptr as *mut vec_znx_bigcoeff_t,
c.ptr as *mut vec_znx_big_t,
c.cols() as u64,
b.ptr as *mut vec_znx_bigcoeff_t,
b.ptr as *mut vec_znx_big_t,
b.cols() as u64,
a.as_ptr(),
a.cols() as u64,
@@ -207,9 +217,9 @@ impl VecZnxBigOps for Module {
unsafe {
vec_znx_big::vec_znx_big_add_small(
self.ptr,
b.ptr as *mut vec_znx_bigcoeff_t,
b.ptr as *mut vec_znx_big_t,
b.cols() as u64,
b.ptr as *mut vec_znx_bigcoeff_t,
b.ptr as *mut vec_znx_big_t,
b.cols() as u64,
a.as_ptr(),
a.cols() as u64,
@@ -246,7 +256,7 @@ impl VecZnxBigOps for Module {
b.as_mut_ptr(),
b.cols() as u64,
b.n() as u64,
a.ptr as *mut vec_znx_bigcoeff_t,
a.ptr as *mut vec_znx_big_t,
a.cols() as u64,
tmp_bytes.as_mut_ptr(),
)
@@ -284,7 +294,7 @@ impl VecZnxBigOps for Module {
res.as_mut_ptr(),
res.cols() as u64,
res.n() as u64,
a.ptr as *mut vec_znx_bigcoeff_t,
a.ptr as *mut vec_znx_big_t,
a_range_begin as u64,
a_range_xend as u64,
a_range_step as u64,
@@ -298,9 +308,9 @@ impl VecZnxBigOps for Module {
vec_znx_big::vec_znx_big_automorphism(
self.ptr,
gal_el,
b.ptr as *mut vec_znx_bigcoeff_t,
b.ptr as *mut vec_znx_big_t,
b.cols() as u64,
a.ptr as *mut vec_znx_bigcoeff_t,
a.ptr as *mut vec_znx_big_t,
a.cols() as u64,
);
}
@@ -311,9 +321,9 @@ impl VecZnxBigOps for Module {
vec_znx_big::vec_znx_big_automorphism(
self.ptr,
gal_el,
a.ptr as *mut vec_znx_bigcoeff_t,
a.ptr as *mut vec_znx_big_t,
a.cols() as u64,
a.ptr as *mut vec_znx_bigcoeff_t,
a.ptr as *mut vec_znx_big_t,
a.cols() as u64,
);
}

View File

@@ -1,15 +1,15 @@
use crate::ffi::vec_znx_big::vec_znx_bigcoeff_t;
use crate::ffi::vec_znx_big::vec_znx_big_t;
use crate::ffi::vec_znx_dft;
use crate::ffi::vec_znx_dft::{bytes_of_vec_znx_dft, vec_znx_dft_t};
use crate::{alloc_aligned, VecZnx};
use crate::{assert_alignement, Infos, Module, VecZnxBig, MODULETYPE};
use crate::{assert_alignement, Infos, Module, VecZnxBig, BACKEND};
pub struct VecZnxDft {
pub data: Vec<u8>,
pub ptr: *mut u8,
pub n: usize,
pub cols: usize,
pub backend: MODULETYPE,
pub backend: BACKEND,
}
impl VecZnxDft {
@@ -69,7 +69,7 @@ impl VecZnxDft {
self.cols
}
pub fn backend(&self) -> MODULETYPE {
pub fn backend(&self) -> BACKEND {
self.backend
}
@@ -133,17 +133,17 @@ pub trait VecZnxDftOps {
fn vec_znx_idft_tmp_bytes(&self) -> usize;
/// b <- IDFT(a), uses a as scratch space.
fn vec_znx_idft_tmp_a(&self, b: &mut VecZnxBig, a: &mut VecZnxDft, a_limbs: usize);
fn vec_znx_idft_tmp_a(&self, b: &mut VecZnxBig, a: &mut VecZnxDft, a_cols: usize);
fn vec_znx_idft(
&self,
b: &mut VecZnxBig,
a: &mut VecZnxDft,
a_limbs: usize,
a: &VecZnxDft,
a_cols: usize,
tmp_bytes: &mut [u8],
);
fn vec_znx_dft(&self, b: &mut VecZnxDft, a: &VecZnx, a_limbs: usize);
fn vec_znx_dft(&self, b: &mut VecZnxDft, a: &VecZnx, a_cols: usize);
}
impl VecZnxDftOps for Module {
@@ -177,20 +177,20 @@ impl VecZnxDftOps for Module {
unsafe { bytes_of_vec_znx_dft(self.ptr, cols as u64) as usize }
}
fn vec_znx_idft_tmp_a(&self, b: &mut VecZnxBig, a: &mut VecZnxDft, a_limbs: usize) {
fn vec_znx_idft_tmp_a(&self, b: &mut VecZnxBig, a: &mut VecZnxDft, a_cols: usize) {
debug_assert!(
b.cols() >= a_limbs,
"invalid c_vector: b_vector.cols()={} < a_limbs={}",
b.cols() >= a_cols,
"invalid c_vector: b_vector.cols()={} < a_cols={}",
b.cols(),
a_limbs
a_cols
);
unsafe {
vec_znx_dft::vec_znx_idft_tmp_a(
self.ptr,
b.ptr as *mut vec_znx_bigcoeff_t,
b.ptr as *mut vec_znx_big_t,
b.cols() as u64,
a.ptr as *mut vec_znx_dft_t,
a_limbs as u64,
a_cols as u64,
)
}
}
@@ -226,7 +226,7 @@ impl VecZnxDftOps for Module {
fn vec_znx_idft(
&self,
b: &mut VecZnxBig,
a: &mut VecZnxDft,
a: &VecZnxDft,
a_cols: usize,
tmp_bytes: &mut [u8],
) {
@@ -243,7 +243,7 @@ impl VecZnxDftOps for Module {
a_cols
);
debug_assert!(
tmp_bytes.len() <= <Module as VecZnxDftOps>::vec_znx_idft_tmp_bytes(self),
tmp_bytes.len() >= <Module as VecZnxDftOps>::vec_znx_idft_tmp_bytes(self),
"invalid tmp_bytes: tmp_bytes.len()={} < self.vec_znx_idft_tmp_bytes()={}",
tmp_bytes.len(),
<Module as VecZnxDftOps>::vec_znx_idft_tmp_bytes(self)
@@ -255,9 +255,9 @@ impl VecZnxDftOps for Module {
unsafe {
vec_znx_dft::vec_znx_idft(
self.ptr,
b.ptr as *mut vec_znx_bigcoeff_t,
b.ptr as *mut vec_znx_big_t,
a.cols() as u64,
a.ptr as *mut vec_znx_dft_t,
a.ptr as *const vec_znx_dft_t,
a_cols as u64,
tmp_bytes.as_mut_ptr(),
)

View File

@@ -1,6 +1,9 @@
use crate::ffi::vec_znx_big::vec_znx_big_t;
use crate::ffi::vec_znx_dft::vec_znx_dft_t;
use crate::ffi::vmp::{self, vmp_pmat_t};
use crate::{alloc_aligned, assert_alignement, Infos, Module, VecZnx, VecZnxDft, MODULETYPE};
use crate::{
alloc_aligned, assert_alignement, Infos, Module, VecZnx, VecZnxBig, VecZnxDft, BACKEND,
};
/// Vector Matrix Product Prepared Matrix: a vector of [VecZnx],
/// stored as a 3D matrix in the DFT domain in a single contiguous array.
@@ -23,7 +26,7 @@ pub struct VmpPMat {
/// The ring degree of each [VecZnxDft].
n: usize,
backend: MODULETYPE,
backend: BACKEND,
}
impl Infos for VmpPMat {
@@ -59,7 +62,7 @@ impl VmpPMat {
self.ptr
}
pub fn borrowed(&self) -> bool{
pub fn borrowed(&self) -> bool {
self.data.len() == 0
}
@@ -167,7 +170,7 @@ pub trait VmpPMatOps {
/// The size of buf can be obtained with [VmpPMatOps::vmp_prepare_tmp_bytes].
fn vmp_prepare_dblptr(&self, b: &mut VmpPMat, a: &[&[i64]], buf: &mut [u8]);
/// Prepares the ith-row of [VmpPMat] from a vector of [VecZnx].
/// Prepares the ith-row of [VmpPMat] from a [VecZnx].
///
/// # Arguments
///
@@ -179,6 +182,35 @@ pub trait VmpPMatOps {
/// The size of buf can be obtained with [VmpPMatOps::vmp_prepare_tmp_bytes].
fn vmp_prepare_row(&self, b: &mut VmpPMat, a: &[i64], row_i: usize, tmp_bytes: &mut [u8]);
/// Extracts the ith-row of [VmpPMat] into a [VecZnxBig].
///
/// # Arguments
///
/// * `b`: the [VecZnxBig] to on which to extract the row of the [VmpPMat].
/// * `a`: [VmpPMat] on which the values are encoded.
/// * `row_i`: the index of the row to extract.
fn vmp_extract_row(&self, b: &mut VecZnxBig, a: &VmpPMat, row_i: usize);
/// Prepares the ith-row of [VmpPMat] from a [VecZnxDft].
///
/// # Arguments
///
/// * `b`: [VmpPMat] on which the values are encoded.
/// * `a`: the [VecZnxDft] to encode on the [VmpPMat].
/// * `row_i`: the index of the row to prepare.
///
/// The size of buf can be obtained with [VmpPMatOps::vmp_prepare_tmp_bytes].
fn vmp_prepare_row_dft(&self, b: &mut VmpPMat, a: &VecZnxDft, row_i: usize);
/// Extracts the ith-row of [VmpPMat] into a [VecZnxDft].
///
/// # Arguments
///
/// * `b`: the [VecZnxDft] to on which to extract the row of the [VmpPMat].
/// * `a`: [VmpPMat] on which the values are encoded.
/// * `row_i`: the index of the row to extract.
fn vmp_extract_row_dft(&self, b: &mut VecZnxDft, a: &VmpPMat, row_i: usize);
/// Returns the size of the stratch space necessary for [VmpPMatOps::vmp_apply_dft].
///
/// # Arguments
@@ -375,6 +407,60 @@ impl VmpPMatOps for Module {
}
}
fn vmp_extract_row(&self, b: &mut VecZnxBig, a: &VmpPMat, row_i: usize) {
#[cfg(debug_assertions)]
{
assert_eq!(a.n(), b.n());
assert_eq!(a.cols(), b.cols());
}
unsafe {
vmp::vmp_extract_row(
self.ptr,
b.ptr as *mut vec_znx_big_t,
a.as_ptr() as *const vmp_pmat_t,
row_i as u64,
a.rows() as u64,
a.cols() as u64,
);
}
}
fn vmp_prepare_row_dft(&self, b: &mut VmpPMat, a: &VecZnxDft, row_i: usize) {
#[cfg(debug_assertions)]
{
assert_eq!(a.n(), b.n());
assert_eq!(a.cols(), b.cols());
}
unsafe {
vmp::vmp_prepare_row_dft(
self.ptr,
b.as_mut_ptr() as *mut vmp_pmat_t,
a.ptr as *const vec_znx_dft_t,
row_i as u64,
b.rows() as u64,
b.cols() as u64,
);
}
}
fn vmp_extract_row_dft(&self, b: &mut VecZnxDft, a: &VmpPMat, row_i: usize) {
#[cfg(debug_assertions)]
{
assert_eq!(a.n(), b.n());
assert_eq!(a.cols(), b.cols());
}
unsafe {
vmp::vmp_extract_row_dft(
self.ptr,
b.ptr as *mut vec_znx_dft_t,
a.as_ptr() as *const vmp_pmat_t,
row_i as u64,
a.rows() as u64,
a.cols() as u64,
);
}
}
fn vmp_apply_dft_tmp_bytes(
&self,
res_cols: usize,
@@ -489,3 +575,52 @@ impl VmpPMatOps for Module {
}
}
}
#[cfg(test)]
mod tests {
use crate::{
alloc_aligned, Module, Sampling, VecZnx, VecZnxBig, VecZnxBigOps, VecZnxDft, VecZnxDftOps,
VecZnxOps, VmpPMat, VmpPMatOps,
};
use sampling::source::Source;
#[test]
fn vmp_prepare_row_dft() {
let module: Module = Module::new(32, crate::BACKEND::FFT64);
let vpmat_rows: usize = 4;
let vpmat_cols: usize = 5;
let log_base2k: usize = 8;
let mut a: VecZnx = module.new_vec_znx(vpmat_cols);
let mut a_dft: VecZnxDft = module.new_vec_znx_dft(vpmat_cols);
let mut a_big: VecZnxBig = module.new_vec_znx_big(vpmat_cols);
let mut b_big: VecZnxBig = module.new_vec_znx_big(vpmat_cols);
let mut b_dft: VecZnxDft = module.new_vec_znx_dft(vpmat_cols);
let mut vmpmat_0: VmpPMat = module.new_vmp_pmat(vpmat_rows, vpmat_cols);
let mut vmpmat_1: VmpPMat = module.new_vmp_pmat(vpmat_rows, vpmat_cols);
let mut tmp_bytes: Vec<u8> =
alloc_aligned(module.vmp_prepare_tmp_bytes(vpmat_rows, vpmat_cols));
for row_i in 0..vpmat_rows {
let mut source: Source = Source::new([0u8; 32]);
module.fill_uniform(log_base2k, &mut a, vpmat_cols, &mut source);
module.vec_znx_dft(&mut a_dft, &a, vpmat_cols);
module.vmp_prepare_row(&mut vmpmat_0, &a.raw(), row_i, &mut tmp_bytes);
// Checks that prepare(vmp_pmat, a) = prepare_dft(vmp_pmat, a_dft)
module.vmp_prepare_row_dft(&mut vmpmat_1, &a_dft, row_i);
assert_eq!(vmpmat_0.raw::<u8>(), vmpmat_1.raw::<u8>());
// Checks that a_dft = extract_dft(prepare(vmp_pmat, a), b_dft)
module.vmp_extract_row_dft(&mut b_dft, &vmpmat_0, row_i);
assert_eq!(a_dft.raw::<u8>(&module), b_dft.raw::<u8>(&module));
// Checks that a_big = extract(prepare_dft(vmp_pmat, a_dft), b_big)
module.vmp_extract_row(&mut b_big, &vmpmat_0, row_i);
module.vec_znx_idft(&mut a_big, &a_dft, vpmat_cols, &mut tmp_bytes);
assert_eq!(a_big.raw::<i64>(&module), b_big.raw::<i64>(&module));
}
module.free();
}
}

View File

@@ -1,5 +1,5 @@
use base2k::{
Infos, MODULETYPE, Module, Sampling, SvpPPolOps, VecZnx, VecZnxDft, VecZnxDftOps, VecZnxOps,
Infos, BACKEND, Module, Sampling, SvpPPolOps, VecZnx, VecZnxDft, VecZnxDftOps, VecZnxOps,
VmpPMat, alloc_aligned_u8,
};
use criterion::{BenchmarkId, Criterion, criterion_group, criterion_main};
@@ -36,7 +36,7 @@ fn bench_gadget_product_inplace(c: &mut Criterion) {
for log_n in 10..11 {
let params_lit: ParametersLiteral = ParametersLiteral {
backend: MODULETYPE::FFT64,
backend: BACKEND::FFT64,
log_n: log_n,
log_q: 32,
log_p: 0,

View File

@@ -10,7 +10,7 @@ use sampling::source::Source;
fn main() {
let params_lit: ParametersLiteral = ParametersLiteral {
backend: base2k::MODULETYPE::FFT64,
backend: base2k::BACKEND::FFT64,
log_n: 10,
log_q: 54,
log_p: 0,

View File

@@ -12,7 +12,7 @@ use sampling::source::{Source, new_seed};
fn main() {
let n: usize = 32;
let module: Module = Module::new(n, base2k::MODULETYPE::FFT64);
let module: Module = Module::new(n, base2k::BACKEND::FFT64);
let log_base2k: usize = 16;
let log_k: usize = 32;
let cols: usize = 4;

View File

@@ -97,7 +97,7 @@ mod test {
plaintext::Plaintext,
};
use base2k::{
Infos, MODULETYPE, Sampling, SvpPPolOps, VecZnx, VecZnxBig, VecZnxBigOps, VecZnxDft,
Infos, BACKEND, Sampling, SvpPPolOps, VecZnx, VecZnxBig, VecZnxBigOps, VecZnxDft,
VecZnxDftOps, VecZnxOps, VmpPMat, alloc_aligned_u8,
};
use sampling::source::{Source, new_seed};
@@ -110,7 +110,7 @@ mod test {
// Basic parameters with enough limbs to test edge cases
let params_lit: ParametersLiteral = ParametersLiteral {
backend: MODULETYPE::FFT64,
backend: BACKEND::FFT64,
log_n: 12,
log_q: q_cols * log_base2k,
log_p: p_cols * log_base2k,

View File

@@ -1,7 +1,7 @@
use base2k::module::{MODULETYPE, Module};
use base2k::module::{BACKEND, Module};
pub struct ParametersLiteral {
pub backend: MODULETYPE,
pub backend: BACKEND,
pub log_n: usize,
pub log_q: usize,
pub log_p: usize,