mirror of
https://github.com/arnaucube/poulpy.git
synced 2026-02-10 13:16:44 +01:00
prototype trait for Elem<T> + new ciphertext for VmPPmat
This commit is contained in:
@@ -1,5 +1,5 @@
|
||||
use crate::ffi::vmp;
|
||||
use crate::{Infos, Module, VecZnx, VecZnxApi, VecZnxDft};
|
||||
use crate::{Infos, Module, VecZnxApi, VecZnxDft};
|
||||
|
||||
/// Vector Matrix Product Prepared Matrix: a vector of [VecZnx],
|
||||
/// stored as a 3D matrix in the DFT domain in a single contiguous array.
|
||||
@@ -15,7 +15,7 @@ pub struct VmpPMat {
|
||||
pub data: *mut vmp::vmp_pmat_t,
|
||||
/// The number of [VecZnxDft].
|
||||
pub rows: usize,
|
||||
/// The number of limbs in each [VecZnxDft].
|
||||
/// The number of cols in each [VecZnxDft].
|
||||
pub cols: usize,
|
||||
/// The ring degree of each [VecZnxDft].
|
||||
pub n: usize,
|
||||
@@ -86,7 +86,7 @@ pub trait VmpPMatOps {
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `rows`: number of rows (number of [VecZnxDft]).
|
||||
/// * `cols`: number of cols (number of limbs of each [VecZnxDft]).
|
||||
/// * `cols`: number of cols (number of cols of each [VecZnxDft]).
|
||||
fn new_vmp_pmat(&self, rows: usize, cols: usize) -> VmpPMat;
|
||||
|
||||
/// Returns the number of bytes needed as scratch space for [VmpPMatOps::vmp_prepare_contiguous].
|
||||
@@ -153,15 +153,17 @@ pub trait VmpPMatOps {
|
||||
/// vecznx.push(module.new_vec_znx(cols));
|
||||
/// });
|
||||
///
|
||||
/// let slices: Vec<&[i64]> = vecznx.iter().map(|v| v.data.as_slice()).collect();
|
||||
///
|
||||
/// let mut buf: Vec<u8> = vec![u8::default(); module.vmp_prepare_tmp_bytes(rows, cols)];
|
||||
///
|
||||
/// let mut vmp_pmat: VmpPMat = module.new_vmp_pmat(rows, cols);
|
||||
/// module.vmp_prepare_dblptr(&mut vmp_pmat, &vecznx, &mut buf);
|
||||
/// module.vmp_prepare_dblptr(&mut vmp_pmat, &slices, &mut buf);
|
||||
///
|
||||
/// vmp_pmat.free();
|
||||
/// module.free();
|
||||
/// ```
|
||||
fn vmp_prepare_dblptr<T: VecZnxApi + Infos>(&self, b: &mut VmpPMat, a: &Vec<T>, buf: &mut [u8]);
|
||||
fn vmp_prepare_dblptr(&self, b: &mut VmpPMat, a: &[&[i64]], buf: &mut [u8]);
|
||||
|
||||
/// Prepares the ith-row of [VmpPMat] from a vector of [VecZnx].
|
||||
///
|
||||
@@ -175,7 +177,7 @@ pub trait VmpPMatOps {
|
||||
/// The size of buf can be obtained with [VmpPMatOps::vmp_prepare_tmp_bytes].
|
||||
/// /// # Example
|
||||
/// ```
|
||||
/// use base2k::{Module, FFT64, VmpPMat, VmpPMatOps, VecZnx, VecZnxOps, Free};
|
||||
/// use base2k::{Module, FFT64, VmpPMat, VmpPMatOps, VecZnx, VecZnxApi, VecZnxOps, Free};
|
||||
/// use std::cmp::min;
|
||||
///
|
||||
/// let n: usize = 1024;
|
||||
@@ -188,31 +190,25 @@ pub trait VmpPMatOps {
|
||||
/// let mut buf: Vec<u8> = vec![u8::default(); module.vmp_prepare_tmp_bytes(rows, cols)];
|
||||
///
|
||||
/// let mut vmp_pmat: VmpPMat = module.new_vmp_pmat(rows, cols);
|
||||
/// module.vmp_prepare_row(&mut vmp_pmat, &vecznx, 0, &mut buf);
|
||||
/// module.vmp_prepare_row(&mut vmp_pmat, vecznx.raw(), 0, &mut buf);
|
||||
///
|
||||
/// vmp_pmat.free();
|
||||
/// module.free();
|
||||
/// ```
|
||||
fn vmp_prepare_row<T: VecZnxApi + Infos>(
|
||||
&self,
|
||||
b: &mut VmpPMat,
|
||||
a: &T,
|
||||
row_i: usize,
|
||||
tmp_bytes: &mut [u8],
|
||||
);
|
||||
fn vmp_prepare_row(&self, b: &mut VmpPMat, a: &[i64], row_i: usize, tmp_bytes: &mut [u8]);
|
||||
|
||||
/// Returns the size of the stratch space necessary for [VmpPMatOps::vmp_apply_dft].
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `c_limbs`: number of limbs of the output [VecZnxDft].
|
||||
/// * `a_limbs`: number of limbs of the input [VecZnx].
|
||||
/// * `c_cols`: number of cols of the output [VecZnxDft].
|
||||
/// * `a_cols`: number of cols of the input [VecZnx].
|
||||
/// * `rows`: number of rows of the input [VmpPMat].
|
||||
/// * `cols`: number of cols of the input [VmpPMat].
|
||||
fn vmp_apply_dft_tmp_bytes(
|
||||
&self,
|
||||
c_limbs: usize,
|
||||
a_limbs: usize,
|
||||
c_cols: usize,
|
||||
a_cols: usize,
|
||||
rows: usize,
|
||||
cols: usize,
|
||||
) -> usize;
|
||||
@@ -223,8 +219,8 @@ pub trait VmpPMatOps {
|
||||
/// where each [crate::Scalar] is a limb of the input [VecZnxDft] (equivalent to an [crate::SvpPPol])
|
||||
/// and each vector a [VecZnxDft] (row) of the [VmpPMat].
|
||||
///
|
||||
/// As such, given an input [VecZnx] of `i` limbs and a [VmpPMat] of `i` rows and
|
||||
/// `j` cols, the output is a [VecZnx] of `j` limbs.
|
||||
/// As such, given an input [VecZnx] of `i` cols and a [VmpPMat] of `i` rows and
|
||||
/// `j` cols, the output is a [VecZnx] of `j` cols.
|
||||
///
|
||||
/// If there is a mismatch between the dimensions the largest valid ones are used.
|
||||
///
|
||||
@@ -249,18 +245,18 @@ pub trait VmpPMatOps {
|
||||
/// let n = 1024;
|
||||
///
|
||||
/// let module: Module = Module::new::<FFT64>(n);
|
||||
/// let limbs: usize = 5;
|
||||
/// let cols: usize = 5;
|
||||
///
|
||||
/// let rows: usize = limbs;
|
||||
/// let cols: usize = limbs + 1;
|
||||
/// let c_limbs: usize = cols;
|
||||
/// let a_limbs: usize = limbs;
|
||||
/// let tmp_bytes: usize = module.vmp_apply_dft_tmp_bytes(c_limbs, a_limbs, rows, cols);
|
||||
/// let rows: usize = cols;
|
||||
/// let cols: usize = cols + 1;
|
||||
/// let c_cols: usize = cols;
|
||||
/// let a_cols: usize = cols;
|
||||
/// let tmp_bytes: usize = module.vmp_apply_dft_tmp_bytes(c_cols, a_cols, rows, cols);
|
||||
///
|
||||
/// let mut buf: Vec<u8> = vec![0; tmp_bytes];
|
||||
/// let mut vmp_pmat: VmpPMat = module.new_vmp_pmat(rows, cols);
|
||||
///
|
||||
/// let a: VecZnx = module.new_vec_znx(limbs);
|
||||
/// let a: VecZnx = module.new_vec_znx(cols);
|
||||
/// let mut c_dft: VecZnxDft = module.new_vec_znx_dft(cols);
|
||||
/// module.vmp_apply_dft(&mut c_dft, &a, &vmp_pmat, &mut buf);
|
||||
///
|
||||
@@ -280,14 +276,14 @@ pub trait VmpPMatOps {
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `c_limbs`: number of limbs of the output [VecZnxDft].
|
||||
/// * `a_limbs`: number of limbs of the input [VecZnxDft].
|
||||
/// * `c_cols`: number of cols of the output [VecZnxDft].
|
||||
/// * `a_cols`: number of cols of the input [VecZnxDft].
|
||||
/// * `rows`: number of rows of the input [VmpPMat].
|
||||
/// * `cols`: number of cols of the input [VmpPMat].
|
||||
fn vmp_apply_dft_to_dft_tmp_bytes(
|
||||
&self,
|
||||
c_limbs: usize,
|
||||
a_limbs: usize,
|
||||
c_cols: usize,
|
||||
a_cols: usize,
|
||||
rows: usize,
|
||||
cols: usize,
|
||||
) -> usize;
|
||||
@@ -299,8 +295,8 @@ pub trait VmpPMatOps {
|
||||
/// where each [crate::Scalar] is a limb of the input [VecZnxDft] (equivalent to an [crate::SvpPPol])
|
||||
/// and each vector a [VecZnxDft] (row) of the [VmpPMat].
|
||||
///
|
||||
/// As such, given an input [VecZnx] of `i` limbs and a [VmpPMat] of `i` rows and
|
||||
/// `j` cols, the output is a [VecZnx] of `j` limbs.
|
||||
/// As such, given an input [VecZnx] of `i` cols and a [VmpPMat] of `i` rows and
|
||||
/// `j` cols, the output is a [VecZnx] of `j` cols.
|
||||
///
|
||||
/// If there is a mismatch between the dimensions the largest valid ones are used.
|
||||
///
|
||||
@@ -325,18 +321,18 @@ pub trait VmpPMatOps {
|
||||
/// let n = 1024;
|
||||
///
|
||||
/// let module: Module = Module::new::<FFT64>(n);
|
||||
/// let limbs: usize = 5;
|
||||
/// let cols: usize = 5;
|
||||
///
|
||||
/// let rows: usize = limbs;
|
||||
/// let cols: usize = limbs + 1;
|
||||
/// let c_limbs: usize = cols;
|
||||
/// let a_limbs: usize = limbs;
|
||||
/// let tmp_bytes: usize = module.vmp_apply_dft_to_dft_tmp_bytes(c_limbs, a_limbs, rows, cols);
|
||||
/// let rows: usize = cols;
|
||||
/// let cols: usize = cols + 1;
|
||||
/// let c_cols: usize = cols;
|
||||
/// let a_cols: usize = cols;
|
||||
/// let tmp_bytes: usize = module.vmp_apply_dft_to_dft_tmp_bytes(c_cols, a_cols, rows, cols);
|
||||
///
|
||||
/// let mut buf: Vec<u8> = vec![0; tmp_bytes];
|
||||
/// let mut vmp_pmat: VmpPMat = module.new_vmp_pmat(rows, cols);
|
||||
///
|
||||
/// let a_dft: VecZnxDft = module.new_vec_znx_dft(limbs);
|
||||
/// let a_dft: VecZnxDft = module.new_vec_znx_dft(cols);
|
||||
/// let mut c_dft: VecZnxDft = module.new_vec_znx_dft(cols);
|
||||
/// module.vmp_apply_dft_to_dft(&mut c_dft, &a_dft, &vmp_pmat, &mut buf);
|
||||
///
|
||||
@@ -354,8 +350,8 @@ pub trait VmpPMatOps {
|
||||
/// where each [crate::Scalar] is a limb of the input [VecZnxDft] (equivalent to an [crate::SvpPPol])
|
||||
/// and each vector a [VecZnxDft] (row) of the [VmpPMat].
|
||||
///
|
||||
/// As such, given an input [VecZnx] of `i` limbs and a [VmpPMat] of `i` rows and
|
||||
/// `j` cols, the output is a [VecZnx] of `j` limbs.
|
||||
/// As such, given an input [VecZnx] of `i` cols and a [VmpPMat] of `i` rows and
|
||||
/// `j` cols, the output is a [VecZnx] of `j` cols.
|
||||
///
|
||||
/// If there is a mismatch between the dimensions the largest valid ones are used.
|
||||
///
|
||||
@@ -379,17 +375,17 @@ pub trait VmpPMatOps {
|
||||
/// let n = 1024;
|
||||
///
|
||||
/// let module: Module = Module::new::<FFT64>(n);
|
||||
/// let limbs: usize = 5;
|
||||
/// let cols: usize = 5;
|
||||
///
|
||||
/// let rows: usize = limbs;
|
||||
/// let cols: usize = limbs + 1;
|
||||
/// let tmp_bytes: usize = module.vmp_apply_dft_to_dft_tmp_bytes(limbs, limbs, rows, cols);
|
||||
/// let rows: usize = cols;
|
||||
/// let cols: usize = cols + 1;
|
||||
/// let tmp_bytes: usize = module.vmp_apply_dft_to_dft_tmp_bytes(cols, cols, rows, cols);
|
||||
///
|
||||
/// let mut buf: Vec<u8> = vec![0; tmp_bytes];
|
||||
/// let a: VecZnx = module.new_vec_znx(limbs);
|
||||
/// let a: VecZnx = module.new_vec_znx(cols);
|
||||
/// let mut vmp_pmat: VmpPMat = module.new_vmp_pmat(rows, cols);
|
||||
///
|
||||
/// let mut c_dft: VecZnxDft = module.new_vec_znx_dft(limbs);
|
||||
/// let mut c_dft: VecZnxDft = module.new_vec_znx_dft(cols);
|
||||
/// module.vmp_apply_dft_to_dft_inplace(&mut c_dft, &vmp_pmat, &mut buf);
|
||||
///
|
||||
/// c_dft.free();
|
||||
@@ -428,12 +424,7 @@ impl VmpPMatOps for Module {
|
||||
}
|
||||
}
|
||||
|
||||
fn vmp_prepare_dblptr<T: VecZnxApi + Infos>(
|
||||
&self,
|
||||
b: &mut VmpPMat,
|
||||
a: &Vec<T>,
|
||||
buf: &mut [u8],
|
||||
) {
|
||||
fn vmp_prepare_dblptr(&self, b: &mut VmpPMat, a: &[&[i64]], buf: &mut [u8]) {
|
||||
let ptrs: Vec<*const i64> = a.iter().map(|v| v.as_ptr()).collect();
|
||||
unsafe {
|
||||
vmp::vmp_prepare_dblptr(
|
||||
@@ -447,13 +438,7 @@ impl VmpPMatOps for Module {
|
||||
}
|
||||
}
|
||||
|
||||
fn vmp_prepare_row<T: VecZnxApi + Infos>(
|
||||
&self,
|
||||
b: &mut VmpPMat,
|
||||
a: &T,
|
||||
row_i: usize,
|
||||
buf: &mut [u8],
|
||||
) {
|
||||
fn vmp_prepare_row(&self, b: &mut VmpPMat, a: &[i64], row_i: usize, buf: &mut [u8]) {
|
||||
unsafe {
|
||||
vmp::vmp_prepare_row(
|
||||
self.0,
|
||||
@@ -469,16 +454,16 @@ impl VmpPMatOps for Module {
|
||||
|
||||
fn vmp_apply_dft_tmp_bytes(
|
||||
&self,
|
||||
c_limbs: usize,
|
||||
a_limbs: usize,
|
||||
c_cols: usize,
|
||||
a_cols: usize,
|
||||
rows: usize,
|
||||
cols: usize,
|
||||
) -> usize {
|
||||
unsafe {
|
||||
vmp::vmp_apply_dft_tmp_bytes(
|
||||
self.0,
|
||||
c_limbs as u64,
|
||||
a_limbs as u64,
|
||||
c_cols as u64,
|
||||
a_cols as u64,
|
||||
rows as u64,
|
||||
cols as u64,
|
||||
) as usize
|
||||
@@ -496,9 +481,9 @@ impl VmpPMatOps for Module {
|
||||
vmp::vmp_apply_dft(
|
||||
self.0,
|
||||
c.0,
|
||||
c.limbs() as u64,
|
||||
c.cols() as u64,
|
||||
a.as_ptr(),
|
||||
a.limbs() as u64,
|
||||
a.cols() as u64,
|
||||
a.n() as u64,
|
||||
b.data(),
|
||||
b.rows() as u64,
|
||||
@@ -510,16 +495,16 @@ impl VmpPMatOps for Module {
|
||||
|
||||
fn vmp_apply_dft_to_dft_tmp_bytes(
|
||||
&self,
|
||||
c_limbs: usize,
|
||||
a_limbs: usize,
|
||||
c_cols: usize,
|
||||
a_cols: usize,
|
||||
rows: usize,
|
||||
cols: usize,
|
||||
) -> usize {
|
||||
unsafe {
|
||||
vmp::vmp_apply_dft_to_dft_tmp_bytes(
|
||||
self.0,
|
||||
c_limbs as u64,
|
||||
a_limbs as u64,
|
||||
c_cols as u64,
|
||||
a_cols as u64,
|
||||
rows as u64,
|
||||
cols as u64,
|
||||
) as usize
|
||||
@@ -531,9 +516,9 @@ impl VmpPMatOps for Module {
|
||||
vmp::vmp_apply_dft_to_dft(
|
||||
self.0,
|
||||
c.0,
|
||||
c.limbs() as u64,
|
||||
c.cols() as u64,
|
||||
a.0,
|
||||
a.limbs() as u64,
|
||||
a.cols() as u64,
|
||||
b.data(),
|
||||
b.rows() as u64,
|
||||
b.cols() as u64,
|
||||
@@ -547,9 +532,9 @@ impl VmpPMatOps for Module {
|
||||
vmp::vmp_apply_dft_to_dft(
|
||||
self.0,
|
||||
b.0,
|
||||
b.limbs() as u64,
|
||||
b.cols() as u64,
|
||||
b.0,
|
||||
b.limbs() as u64,
|
||||
b.cols() as u64,
|
||||
a.data(),
|
||||
a.rows() as u64,
|
||||
a.cols() as u64,
|
||||
|
||||
Reference in New Issue
Block a user