mirror of
https://github.com/arnaucube/poulpy.git
synced 2026-02-10 21:26:41 +01:00
added Added vmp_extract_row, vmp_extract_row_dft, vmp_extract_tmp_bytes, vmp_prepare_row_dft
-
This commit is contained in:
@@ -1,6 +1,9 @@
|
||||
use crate::ffi::vec_znx_big::vec_znx_big_t;
|
||||
use crate::ffi::vec_znx_dft::vec_znx_dft_t;
|
||||
use crate::ffi::vmp::{self, vmp_pmat_t};
|
||||
use crate::{alloc_aligned, assert_alignement, Infos, Module, VecZnx, VecZnxDft, MODULETYPE};
|
||||
use crate::{
|
||||
alloc_aligned, assert_alignement, Infos, Module, VecZnx, VecZnxBig, VecZnxDft, BACKEND,
|
||||
};
|
||||
|
||||
/// Vector Matrix Product Prepared Matrix: a vector of [VecZnx],
|
||||
/// stored as a 3D matrix in the DFT domain in a single contiguous array.
|
||||
@@ -23,7 +26,7 @@ pub struct VmpPMat {
|
||||
/// The ring degree of each [VecZnxDft].
|
||||
n: usize,
|
||||
|
||||
backend: MODULETYPE,
|
||||
backend: BACKEND,
|
||||
}
|
||||
|
||||
impl Infos for VmpPMat {
|
||||
@@ -59,7 +62,7 @@ impl VmpPMat {
|
||||
self.ptr
|
||||
}
|
||||
|
||||
pub fn borrowed(&self) -> bool{
|
||||
pub fn borrowed(&self) -> bool {
|
||||
self.data.len() == 0
|
||||
}
|
||||
|
||||
@@ -167,7 +170,7 @@ pub trait VmpPMatOps {
|
||||
/// The size of buf can be obtained with [VmpPMatOps::vmp_prepare_tmp_bytes].
|
||||
fn vmp_prepare_dblptr(&self, b: &mut VmpPMat, a: &[&[i64]], buf: &mut [u8]);
|
||||
|
||||
/// Prepares the ith-row of [VmpPMat] from a vector of [VecZnx].
|
||||
/// Prepares the ith-row of [VmpPMat] from a [VecZnx].
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
@@ -179,6 +182,35 @@ pub trait VmpPMatOps {
|
||||
/// The size of buf can be obtained with [VmpPMatOps::vmp_prepare_tmp_bytes].
|
||||
fn vmp_prepare_row(&self, b: &mut VmpPMat, a: &[i64], row_i: usize, tmp_bytes: &mut [u8]);
|
||||
|
||||
/// Extracts the ith-row of [VmpPMat] into a [VecZnxBig].
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `b`: the [VecZnxBig] to on which to extract the row of the [VmpPMat].
|
||||
/// * `a`: [VmpPMat] on which the values are encoded.
|
||||
/// * `row_i`: the index of the row to extract.
|
||||
fn vmp_extract_row(&self, b: &mut VecZnxBig, a: &VmpPMat, row_i: usize);
|
||||
|
||||
/// Prepares the ith-row of [VmpPMat] from a [VecZnxDft].
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `b`: [VmpPMat] on which the values are encoded.
|
||||
/// * `a`: the [VecZnxDft] to encode on the [VmpPMat].
|
||||
/// * `row_i`: the index of the row to prepare.
|
||||
///
|
||||
/// The size of buf can be obtained with [VmpPMatOps::vmp_prepare_tmp_bytes].
|
||||
fn vmp_prepare_row_dft(&self, b: &mut VmpPMat, a: &VecZnxDft, row_i: usize);
|
||||
|
||||
/// Extracts the ith-row of [VmpPMat] into a [VecZnxDft].
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `b`: the [VecZnxDft] to on which to extract the row of the [VmpPMat].
|
||||
/// * `a`: [VmpPMat] on which the values are encoded.
|
||||
/// * `row_i`: the index of the row to extract.
|
||||
fn vmp_extract_row_dft(&self, b: &mut VecZnxDft, a: &VmpPMat, row_i: usize);
|
||||
|
||||
/// Returns the size of the stratch space necessary for [VmpPMatOps::vmp_apply_dft].
|
||||
///
|
||||
/// # Arguments
|
||||
@@ -375,6 +407,60 @@ impl VmpPMatOps for Module {
|
||||
}
|
||||
}
|
||||
|
||||
fn vmp_extract_row(&self, b: &mut VecZnxBig, a: &VmpPMat, row_i: usize) {
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
assert_eq!(a.n(), b.n());
|
||||
assert_eq!(a.cols(), b.cols());
|
||||
}
|
||||
unsafe {
|
||||
vmp::vmp_extract_row(
|
||||
self.ptr,
|
||||
b.ptr as *mut vec_znx_big_t,
|
||||
a.as_ptr() as *const vmp_pmat_t,
|
||||
row_i as u64,
|
||||
a.rows() as u64,
|
||||
a.cols() as u64,
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
fn vmp_prepare_row_dft(&self, b: &mut VmpPMat, a: &VecZnxDft, row_i: usize) {
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
assert_eq!(a.n(), b.n());
|
||||
assert_eq!(a.cols(), b.cols());
|
||||
}
|
||||
unsafe {
|
||||
vmp::vmp_prepare_row_dft(
|
||||
self.ptr,
|
||||
b.as_mut_ptr() as *mut vmp_pmat_t,
|
||||
a.ptr as *const vec_znx_dft_t,
|
||||
row_i as u64,
|
||||
b.rows() as u64,
|
||||
b.cols() as u64,
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
fn vmp_extract_row_dft(&self, b: &mut VecZnxDft, a: &VmpPMat, row_i: usize) {
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
assert_eq!(a.n(), b.n());
|
||||
assert_eq!(a.cols(), b.cols());
|
||||
}
|
||||
unsafe {
|
||||
vmp::vmp_extract_row_dft(
|
||||
self.ptr,
|
||||
b.ptr as *mut vec_znx_dft_t,
|
||||
a.as_ptr() as *const vmp_pmat_t,
|
||||
row_i as u64,
|
||||
a.rows() as u64,
|
||||
a.cols() as u64,
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
fn vmp_apply_dft_tmp_bytes(
|
||||
&self,
|
||||
res_cols: usize,
|
||||
@@ -489,3 +575,52 @@ impl VmpPMatOps for Module {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use crate::{
|
||||
alloc_aligned, Module, Sampling, VecZnx, VecZnxBig, VecZnxBigOps, VecZnxDft, VecZnxDftOps,
|
||||
VecZnxOps, VmpPMat, VmpPMatOps,
|
||||
};
|
||||
use sampling::source::Source;
|
||||
|
||||
#[test]
|
||||
fn vmp_prepare_row_dft() {
|
||||
let module: Module = Module::new(32, crate::BACKEND::FFT64);
|
||||
let vpmat_rows: usize = 4;
|
||||
let vpmat_cols: usize = 5;
|
||||
let log_base2k: usize = 8;
|
||||
let mut a: VecZnx = module.new_vec_znx(vpmat_cols);
|
||||
let mut a_dft: VecZnxDft = module.new_vec_znx_dft(vpmat_cols);
|
||||
let mut a_big: VecZnxBig = module.new_vec_znx_big(vpmat_cols);
|
||||
let mut b_big: VecZnxBig = module.new_vec_znx_big(vpmat_cols);
|
||||
let mut b_dft: VecZnxDft = module.new_vec_znx_dft(vpmat_cols);
|
||||
let mut vmpmat_0: VmpPMat = module.new_vmp_pmat(vpmat_rows, vpmat_cols);
|
||||
let mut vmpmat_1: VmpPMat = module.new_vmp_pmat(vpmat_rows, vpmat_cols);
|
||||
|
||||
let mut tmp_bytes: Vec<u8> =
|
||||
alloc_aligned(module.vmp_prepare_tmp_bytes(vpmat_rows, vpmat_cols));
|
||||
|
||||
for row_i in 0..vpmat_rows {
|
||||
let mut source: Source = Source::new([0u8; 32]);
|
||||
module.fill_uniform(log_base2k, &mut a, vpmat_cols, &mut source);
|
||||
module.vec_znx_dft(&mut a_dft, &a, vpmat_cols);
|
||||
module.vmp_prepare_row(&mut vmpmat_0, &a.raw(), row_i, &mut tmp_bytes);
|
||||
|
||||
// Checks that prepare(vmp_pmat, a) = prepare_dft(vmp_pmat, a_dft)
|
||||
module.vmp_prepare_row_dft(&mut vmpmat_1, &a_dft, row_i);
|
||||
assert_eq!(vmpmat_0.raw::<u8>(), vmpmat_1.raw::<u8>());
|
||||
|
||||
// Checks that a_dft = extract_dft(prepare(vmp_pmat, a), b_dft)
|
||||
module.vmp_extract_row_dft(&mut b_dft, &vmpmat_0, row_i);
|
||||
assert_eq!(a_dft.raw::<u8>(&module), b_dft.raw::<u8>(&module));
|
||||
|
||||
// Checks that a_big = extract(prepare_dft(vmp_pmat, a_dft), b_big)
|
||||
module.vmp_extract_row(&mut b_big, &vmpmat_0, row_i);
|
||||
module.vec_znx_idft(&mut a_big, &a_dft, vpmat_cols, &mut tmp_bytes);
|
||||
assert_eq!(a_big.raw::<i64>(&module), b_big.raw::<i64>(&module));
|
||||
}
|
||||
|
||||
module.free();
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user