added Added vmp_extract_row, vmp_extract_row_dft, vmp_extract_tmp_bytes, vmp_prepare_row_dft

-
This commit is contained in:
Jean-Philippe Bossuat
2025-04-16 11:31:58 +02:00
parent 4c1dbc70e5
commit 89369dcdf9
18 changed files with 293 additions and 181 deletions

View File

@@ -1,6 +1,9 @@
use crate::ffi::vec_znx_big::vec_znx_big_t;
use crate::ffi::vec_znx_dft::vec_znx_dft_t;
use crate::ffi::vmp::{self, vmp_pmat_t};
use crate::{alloc_aligned, assert_alignement, Infos, Module, VecZnx, VecZnxDft, MODULETYPE};
use crate::{
alloc_aligned, assert_alignement, Infos, Module, VecZnx, VecZnxBig, VecZnxDft, BACKEND,
};
/// Vector Matrix Product Prepared Matrix: a vector of [VecZnx],
/// stored as a 3D matrix in the DFT domain in a single contiguous array.
@@ -23,7 +26,7 @@ pub struct VmpPMat {
/// The ring degree of each [VecZnxDft].
n: usize,
backend: MODULETYPE,
backend: BACKEND,
}
impl Infos for VmpPMat {
@@ -59,7 +62,7 @@ impl VmpPMat {
self.ptr
}
pub fn borrowed(&self) -> bool{
pub fn borrowed(&self) -> bool {
self.data.len() == 0
}
@@ -167,7 +170,7 @@ pub trait VmpPMatOps {
/// The size of buf can be obtained with [VmpPMatOps::vmp_prepare_tmp_bytes].
fn vmp_prepare_dblptr(&self, b: &mut VmpPMat, a: &[&[i64]], buf: &mut [u8]);
/// Prepares the ith-row of [VmpPMat] from a vector of [VecZnx].
/// Prepares the ith-row of [VmpPMat] from a [VecZnx].
///
/// # Arguments
///
@@ -179,6 +182,35 @@ pub trait VmpPMatOps {
/// The size of buf can be obtained with [VmpPMatOps::vmp_prepare_tmp_bytes].
fn vmp_prepare_row(&self, b: &mut VmpPMat, a: &[i64], row_i: usize, tmp_bytes: &mut [u8]);
/// Extracts the ith-row of [VmpPMat] into a [VecZnxBig].
///
/// # Arguments
///
/// * `b`: the [VecZnxBig] to on which to extract the row of the [VmpPMat].
/// * `a`: [VmpPMat] on which the values are encoded.
/// * `row_i`: the index of the row to extract.
fn vmp_extract_row(&self, b: &mut VecZnxBig, a: &VmpPMat, row_i: usize);
/// Prepares the ith-row of [VmpPMat] from a [VecZnxDft].
///
/// # Arguments
///
/// * `b`: [VmpPMat] on which the values are encoded.
/// * `a`: the [VecZnxDft] to encode on the [VmpPMat].
/// * `row_i`: the index of the row to prepare.
///
/// The size of buf can be obtained with [VmpPMatOps::vmp_prepare_tmp_bytes].
fn vmp_prepare_row_dft(&self, b: &mut VmpPMat, a: &VecZnxDft, row_i: usize);
/// Extracts the ith-row of [VmpPMat] into a [VecZnxDft].
///
/// # Arguments
///
/// * `b`: the [VecZnxDft] to on which to extract the row of the [VmpPMat].
/// * `a`: [VmpPMat] on which the values are encoded.
/// * `row_i`: the index of the row to extract.
fn vmp_extract_row_dft(&self, b: &mut VecZnxDft, a: &VmpPMat, row_i: usize);
/// Returns the size of the stratch space necessary for [VmpPMatOps::vmp_apply_dft].
///
/// # Arguments
@@ -375,6 +407,60 @@ impl VmpPMatOps for Module {
}
}
fn vmp_extract_row(&self, b: &mut VecZnxBig, a: &VmpPMat, row_i: usize) {
#[cfg(debug_assertions)]
{
assert_eq!(a.n(), b.n());
assert_eq!(a.cols(), b.cols());
}
unsafe {
vmp::vmp_extract_row(
self.ptr,
b.ptr as *mut vec_znx_big_t,
a.as_ptr() as *const vmp_pmat_t,
row_i as u64,
a.rows() as u64,
a.cols() as u64,
);
}
}
fn vmp_prepare_row_dft(&self, b: &mut VmpPMat, a: &VecZnxDft, row_i: usize) {
#[cfg(debug_assertions)]
{
assert_eq!(a.n(), b.n());
assert_eq!(a.cols(), b.cols());
}
unsafe {
vmp::vmp_prepare_row_dft(
self.ptr,
b.as_mut_ptr() as *mut vmp_pmat_t,
a.ptr as *const vec_znx_dft_t,
row_i as u64,
b.rows() as u64,
b.cols() as u64,
);
}
}
fn vmp_extract_row_dft(&self, b: &mut VecZnxDft, a: &VmpPMat, row_i: usize) {
#[cfg(debug_assertions)]
{
assert_eq!(a.n(), b.n());
assert_eq!(a.cols(), b.cols());
}
unsafe {
vmp::vmp_extract_row_dft(
self.ptr,
b.ptr as *mut vec_znx_dft_t,
a.as_ptr() as *const vmp_pmat_t,
row_i as u64,
a.rows() as u64,
a.cols() as u64,
);
}
}
fn vmp_apply_dft_tmp_bytes(
&self,
res_cols: usize,
@@ -489,3 +575,52 @@ impl VmpPMatOps for Module {
}
}
}
#[cfg(test)]
mod tests {
use crate::{
alloc_aligned, Module, Sampling, VecZnx, VecZnxBig, VecZnxBigOps, VecZnxDft, VecZnxDftOps,
VecZnxOps, VmpPMat, VmpPMatOps,
};
use sampling::source::Source;
#[test]
fn vmp_prepare_row_dft() {
let module: Module = Module::new(32, crate::BACKEND::FFT64);
let vpmat_rows: usize = 4;
let vpmat_cols: usize = 5;
let log_base2k: usize = 8;
let mut a: VecZnx = module.new_vec_znx(vpmat_cols);
let mut a_dft: VecZnxDft = module.new_vec_znx_dft(vpmat_cols);
let mut a_big: VecZnxBig = module.new_vec_znx_big(vpmat_cols);
let mut b_big: VecZnxBig = module.new_vec_znx_big(vpmat_cols);
let mut b_dft: VecZnxDft = module.new_vec_znx_dft(vpmat_cols);
let mut vmpmat_0: VmpPMat = module.new_vmp_pmat(vpmat_rows, vpmat_cols);
let mut vmpmat_1: VmpPMat = module.new_vmp_pmat(vpmat_rows, vpmat_cols);
let mut tmp_bytes: Vec<u8> =
alloc_aligned(module.vmp_prepare_tmp_bytes(vpmat_rows, vpmat_cols));
for row_i in 0..vpmat_rows {
let mut source: Source = Source::new([0u8; 32]);
module.fill_uniform(log_base2k, &mut a, vpmat_cols, &mut source);
module.vec_znx_dft(&mut a_dft, &a, vpmat_cols);
module.vmp_prepare_row(&mut vmpmat_0, &a.raw(), row_i, &mut tmp_bytes);
// Checks that prepare(vmp_pmat, a) = prepare_dft(vmp_pmat, a_dft)
module.vmp_prepare_row_dft(&mut vmpmat_1, &a_dft, row_i);
assert_eq!(vmpmat_0.raw::<u8>(), vmpmat_1.raw::<u8>());
// Checks that a_dft = extract_dft(prepare(vmp_pmat, a), b_dft)
module.vmp_extract_row_dft(&mut b_dft, &vmpmat_0, row_i);
assert_eq!(a_dft.raw::<u8>(&module), b_dft.raw::<u8>(&module));
// Checks that a_big = extract(prepare_dft(vmp_pmat, a_dft), b_big)
module.vmp_extract_row(&mut b_big, &vmpmat_0, row_i);
module.vec_znx_idft(&mut a_big, &a_dft, vpmat_cols, &mut tmp_bytes);
assert_eq!(a_big.raw::<i64>(&module), b_big.raw::<i64>(&module));
}
module.free();
}
}