prototype trait for Elem<T> + new ciphertext for VmPPmat

This commit is contained in:
Jean-Philippe Bossuat
2025-02-18 11:04:13 +01:00
parent fdc2f3ac42
commit d486e89761
21 changed files with 767 additions and 811 deletions

View File

@@ -1,5 +1,5 @@
use crate::ffi::vmp;
use crate::{Infos, Module, VecZnx, VecZnxApi, VecZnxDft};
use crate::{Infos, Module, VecZnxApi, VecZnxDft};
/// Vector Matrix Product Prepared Matrix: a vector of [VecZnx],
/// stored as a 3D matrix in the DFT domain in a single contiguous array.
@@ -15,7 +15,7 @@ pub struct VmpPMat {
pub data: *mut vmp::vmp_pmat_t,
/// The number of [VecZnxDft].
pub rows: usize,
/// The number of limbs in each [VecZnxDft].
/// The number of cols in each [VecZnxDft].
pub cols: usize,
/// The ring degree of each [VecZnxDft].
pub n: usize,
@@ -86,7 +86,7 @@ pub trait VmpPMatOps {
/// # Arguments
///
/// * `rows`: number of rows (number of [VecZnxDft]).
/// * `cols`: number of cols (number of limbs of each [VecZnxDft]).
/// * `cols`: number of cols (number of cols of each [VecZnxDft]).
fn new_vmp_pmat(&self, rows: usize, cols: usize) -> VmpPMat;
/// Returns the number of bytes needed as scratch space for [VmpPMatOps::vmp_prepare_contiguous].
@@ -153,15 +153,17 @@ pub trait VmpPMatOps {
/// vecznx.push(module.new_vec_znx(cols));
/// });
///
/// let slices: Vec<&[i64]> = vecznx.iter().map(|v| v.data.as_slice()).collect();
///
/// let mut buf: Vec<u8> = vec![u8::default(); module.vmp_prepare_tmp_bytes(rows, cols)];
///
/// let mut vmp_pmat: VmpPMat = module.new_vmp_pmat(rows, cols);
/// module.vmp_prepare_dblptr(&mut vmp_pmat, &vecznx, &mut buf);
/// module.vmp_prepare_dblptr(&mut vmp_pmat, &slices, &mut buf);
///
/// vmp_pmat.free();
/// module.free();
/// ```
fn vmp_prepare_dblptr<T: VecZnxApi + Infos>(&self, b: &mut VmpPMat, a: &Vec<T>, buf: &mut [u8]);
fn vmp_prepare_dblptr(&self, b: &mut VmpPMat, a: &[&[i64]], buf: &mut [u8]);
/// Prepares the ith-row of [VmpPMat] from a vector of [VecZnx].
///
@@ -175,7 +177,7 @@ pub trait VmpPMatOps {
/// The size of buf can be obtained with [VmpPMatOps::vmp_prepare_tmp_bytes].
/// /// # Example
/// ```
/// use base2k::{Module, FFT64, VmpPMat, VmpPMatOps, VecZnx, VecZnxOps, Free};
/// use base2k::{Module, FFT64, VmpPMat, VmpPMatOps, VecZnx, VecZnxApi, VecZnxOps, Free};
/// use std::cmp::min;
///
/// let n: usize = 1024;
@@ -188,31 +190,25 @@ pub trait VmpPMatOps {
/// let mut buf: Vec<u8> = vec![u8::default(); module.vmp_prepare_tmp_bytes(rows, cols)];
///
/// let mut vmp_pmat: VmpPMat = module.new_vmp_pmat(rows, cols);
/// module.vmp_prepare_row(&mut vmp_pmat, &vecznx, 0, &mut buf);
/// module.vmp_prepare_row(&mut vmp_pmat, vecznx.raw(), 0, &mut buf);
///
/// vmp_pmat.free();
/// module.free();
/// ```
fn vmp_prepare_row<T: VecZnxApi + Infos>(
&self,
b: &mut VmpPMat,
a: &T,
row_i: usize,
tmp_bytes: &mut [u8],
);
fn vmp_prepare_row(&self, b: &mut VmpPMat, a: &[i64], row_i: usize, tmp_bytes: &mut [u8]);
/// Returns the size of the stratch space necessary for [VmpPMatOps::vmp_apply_dft].
///
/// # Arguments
///
/// * `c_limbs`: number of limbs of the output [VecZnxDft].
/// * `a_limbs`: number of limbs of the input [VecZnx].
/// * `c_cols`: number of cols of the output [VecZnxDft].
/// * `a_cols`: number of cols of the input [VecZnx].
/// * `rows`: number of rows of the input [VmpPMat].
/// * `cols`: number of cols of the input [VmpPMat].
fn vmp_apply_dft_tmp_bytes(
&self,
c_limbs: usize,
a_limbs: usize,
c_cols: usize,
a_cols: usize,
rows: usize,
cols: usize,
) -> usize;
@@ -223,8 +219,8 @@ pub trait VmpPMatOps {
/// where each [crate::Scalar] is a limb of the input [VecZnxDft] (equivalent to an [crate::SvpPPol])
/// and each vector a [VecZnxDft] (row) of the [VmpPMat].
///
/// As such, given an input [VecZnx] of `i` limbs and a [VmpPMat] of `i` rows and
/// `j` cols, the output is a [VecZnx] of `j` limbs.
/// As such, given an input [VecZnx] of `i` cols and a [VmpPMat] of `i` rows and
/// `j` cols, the output is a [VecZnx] of `j` cols.
///
/// If there is a mismatch between the dimensions the largest valid ones are used.
///
@@ -249,18 +245,18 @@ pub trait VmpPMatOps {
/// let n = 1024;
///
/// let module: Module = Module::new::<FFT64>(n);
/// let limbs: usize = 5;
/// let cols: usize = 5;
///
/// let rows: usize = limbs;
/// let cols: usize = limbs + 1;
/// let c_limbs: usize = cols;
/// let a_limbs: usize = limbs;
/// let tmp_bytes: usize = module.vmp_apply_dft_tmp_bytes(c_limbs, a_limbs, rows, cols);
/// let rows: usize = cols;
/// let cols: usize = cols + 1;
/// let c_cols: usize = cols;
/// let a_cols: usize = cols;
/// let tmp_bytes: usize = module.vmp_apply_dft_tmp_bytes(c_cols, a_cols, rows, cols);
///
/// let mut buf: Vec<u8> = vec![0; tmp_bytes];
/// let mut vmp_pmat: VmpPMat = module.new_vmp_pmat(rows, cols);
///
/// let a: VecZnx = module.new_vec_znx(limbs);
/// let a: VecZnx = module.new_vec_znx(cols);
/// let mut c_dft: VecZnxDft = module.new_vec_znx_dft(cols);
/// module.vmp_apply_dft(&mut c_dft, &a, &vmp_pmat, &mut buf);
///
@@ -280,14 +276,14 @@ pub trait VmpPMatOps {
///
/// # Arguments
///
/// * `c_limbs`: number of limbs of the output [VecZnxDft].
/// * `a_limbs`: number of limbs of the input [VecZnxDft].
/// * `c_cols`: number of cols of the output [VecZnxDft].
/// * `a_cols`: number of cols of the input [VecZnxDft].
/// * `rows`: number of rows of the input [VmpPMat].
/// * `cols`: number of cols of the input [VmpPMat].
fn vmp_apply_dft_to_dft_tmp_bytes(
&self,
c_limbs: usize,
a_limbs: usize,
c_cols: usize,
a_cols: usize,
rows: usize,
cols: usize,
) -> usize;
@@ -299,8 +295,8 @@ pub trait VmpPMatOps {
/// where each [crate::Scalar] is a limb of the input [VecZnxDft] (equivalent to an [crate::SvpPPol])
/// and each vector a [VecZnxDft] (row) of the [VmpPMat].
///
/// As such, given an input [VecZnx] of `i` limbs and a [VmpPMat] of `i` rows and
/// `j` cols, the output is a [VecZnx] of `j` limbs.
/// As such, given an input [VecZnx] of `i` cols and a [VmpPMat] of `i` rows and
/// `j` cols, the output is a [VecZnx] of `j` cols.
///
/// If there is a mismatch between the dimensions the largest valid ones are used.
///
@@ -325,18 +321,18 @@ pub trait VmpPMatOps {
/// let n = 1024;
///
/// let module: Module = Module::new::<FFT64>(n);
/// let limbs: usize = 5;
/// let cols: usize = 5;
///
/// let rows: usize = limbs;
/// let cols: usize = limbs + 1;
/// let c_limbs: usize = cols;
/// let a_limbs: usize = limbs;
/// let tmp_bytes: usize = module.vmp_apply_dft_to_dft_tmp_bytes(c_limbs, a_limbs, rows, cols);
/// let rows: usize = cols;
/// let cols: usize = cols + 1;
/// let c_cols: usize = cols;
/// let a_cols: usize = cols;
/// let tmp_bytes: usize = module.vmp_apply_dft_to_dft_tmp_bytes(c_cols, a_cols, rows, cols);
///
/// let mut buf: Vec<u8> = vec![0; tmp_bytes];
/// let mut vmp_pmat: VmpPMat = module.new_vmp_pmat(rows, cols);
///
/// let a_dft: VecZnxDft = module.new_vec_znx_dft(limbs);
/// let a_dft: VecZnxDft = module.new_vec_znx_dft(cols);
/// let mut c_dft: VecZnxDft = module.new_vec_znx_dft(cols);
/// module.vmp_apply_dft_to_dft(&mut c_dft, &a_dft, &vmp_pmat, &mut buf);
///
@@ -354,8 +350,8 @@ pub trait VmpPMatOps {
/// where each [crate::Scalar] is a limb of the input [VecZnxDft] (equivalent to an [crate::SvpPPol])
/// and each vector a [VecZnxDft] (row) of the [VmpPMat].
///
/// As such, given an input [VecZnx] of `i` limbs and a [VmpPMat] of `i` rows and
/// `j` cols, the output is a [VecZnx] of `j` limbs.
/// As such, given an input [VecZnx] of `i` cols and a [VmpPMat] of `i` rows and
/// `j` cols, the output is a [VecZnx] of `j` cols.
///
/// If there is a mismatch between the dimensions the largest valid ones are used.
///
@@ -379,17 +375,17 @@ pub trait VmpPMatOps {
/// let n = 1024;
///
/// let module: Module = Module::new::<FFT64>(n);
/// let limbs: usize = 5;
/// let cols: usize = 5;
///
/// let rows: usize = limbs;
/// let cols: usize = limbs + 1;
/// let tmp_bytes: usize = module.vmp_apply_dft_to_dft_tmp_bytes(limbs, limbs, rows, cols);
/// let rows: usize = cols;
/// let cols: usize = cols + 1;
/// let tmp_bytes: usize = module.vmp_apply_dft_to_dft_tmp_bytes(cols, cols, rows, cols);
///
/// let mut buf: Vec<u8> = vec![0; tmp_bytes];
/// let a: VecZnx = module.new_vec_znx(limbs);
/// let a: VecZnx = module.new_vec_znx(cols);
/// let mut vmp_pmat: VmpPMat = module.new_vmp_pmat(rows, cols);
///
/// let mut c_dft: VecZnxDft = module.new_vec_znx_dft(limbs);
/// let mut c_dft: VecZnxDft = module.new_vec_znx_dft(cols);
/// module.vmp_apply_dft_to_dft_inplace(&mut c_dft, &vmp_pmat, &mut buf);
///
/// c_dft.free();
@@ -428,12 +424,7 @@ impl VmpPMatOps for Module {
}
}
fn vmp_prepare_dblptr<T: VecZnxApi + Infos>(
&self,
b: &mut VmpPMat,
a: &Vec<T>,
buf: &mut [u8],
) {
fn vmp_prepare_dblptr(&self, b: &mut VmpPMat, a: &[&[i64]], buf: &mut [u8]) {
let ptrs: Vec<*const i64> = a.iter().map(|v| v.as_ptr()).collect();
unsafe {
vmp::vmp_prepare_dblptr(
@@ -447,13 +438,7 @@ impl VmpPMatOps for Module {
}
}
fn vmp_prepare_row<T: VecZnxApi + Infos>(
&self,
b: &mut VmpPMat,
a: &T,
row_i: usize,
buf: &mut [u8],
) {
fn vmp_prepare_row(&self, b: &mut VmpPMat, a: &[i64], row_i: usize, buf: &mut [u8]) {
unsafe {
vmp::vmp_prepare_row(
self.0,
@@ -469,16 +454,16 @@ impl VmpPMatOps for Module {
fn vmp_apply_dft_tmp_bytes(
&self,
c_limbs: usize,
a_limbs: usize,
c_cols: usize,
a_cols: usize,
rows: usize,
cols: usize,
) -> usize {
unsafe {
vmp::vmp_apply_dft_tmp_bytes(
self.0,
c_limbs as u64,
a_limbs as u64,
c_cols as u64,
a_cols as u64,
rows as u64,
cols as u64,
) as usize
@@ -496,9 +481,9 @@ impl VmpPMatOps for Module {
vmp::vmp_apply_dft(
self.0,
c.0,
c.limbs() as u64,
c.cols() as u64,
a.as_ptr(),
a.limbs() as u64,
a.cols() as u64,
a.n() as u64,
b.data(),
b.rows() as u64,
@@ -510,16 +495,16 @@ impl VmpPMatOps for Module {
fn vmp_apply_dft_to_dft_tmp_bytes(
&self,
c_limbs: usize,
a_limbs: usize,
c_cols: usize,
a_cols: usize,
rows: usize,
cols: usize,
) -> usize {
unsafe {
vmp::vmp_apply_dft_to_dft_tmp_bytes(
self.0,
c_limbs as u64,
a_limbs as u64,
c_cols as u64,
a_cols as u64,
rows as u64,
cols as u64,
) as usize
@@ -531,9 +516,9 @@ impl VmpPMatOps for Module {
vmp::vmp_apply_dft_to_dft(
self.0,
c.0,
c.limbs() as u64,
c.cols() as u64,
a.0,
a.limbs() as u64,
a.cols() as u64,
b.data(),
b.rows() as u64,
b.cols() as u64,
@@ -547,9 +532,9 @@ impl VmpPMatOps for Module {
vmp::vmp_apply_dft_to_dft(
self.0,
b.0,
b.limbs() as u64,
b.cols() as u64,
b.0,
b.limbs() as u64,
b.cols() as u64,
a.data(),
a.rows() as u64,
a.cols() as u64,