mirror of
https://github.com/arnaucube/poulpy.git
synced 2026-02-10 13:16:44 +01:00
wip
This commit is contained in:
@@ -2,8 +2,8 @@ use crate::ffi::vec_znx_dft::vec_znx_dft_t;
|
||||
use crate::ffi::vmp;
|
||||
use crate::znx_base::{ZnxInfos, ZnxView, ZnxViewMut};
|
||||
use crate::{
|
||||
Backend, FFT64, MatZnxDft, MatZnxDftAllocOwned, Module, ScratchSpace, VecZnx, VecZnxBig, VecZnxBigOps, VecZnxDft,
|
||||
VecZnxDftAlloc, VecZnxDftOps, assert_alignement, is_aligned,
|
||||
Backend, FFT64, MatZnxDft, MatZnxDftAllocOwned, Module, ScratchBorr, VecZnx, VecZnxBigOps, VecZnxBigScratch, VecZnxDft,
|
||||
VecZnxDftAlloc, VecZnxDftOps,
|
||||
};
|
||||
|
||||
pub trait MatZnxDftAlloc<B> {
|
||||
@@ -36,12 +36,55 @@ pub trait MatZnxDftAlloc<B> {
|
||||
// ) -> MatZnxDft<FFT64>;
|
||||
}
|
||||
|
||||
/// This trait implements methods for vector matrix product,
|
||||
/// that is, multiplying a [VecZnx] with a [MatZnxDft].
|
||||
pub trait MatZnxDftOps<DataMut, Data, B: Backend> {
|
||||
pub trait MatZnxDftScratch {
|
||||
/// Returns the of bytes needed as scratch space for [MatZnxDftOps::vmp_prepare_row]
|
||||
fn vmp_prepare_row_tmp_bytes(&self, cols_out: usize, size: usize) -> usize;
|
||||
|
||||
/// Returns the of bytes needed as scratch space for [MatZnxDftOps::vmp_extract_row]
|
||||
fn vmp_extract_row_tmp_bytes(&self, cols_out: usize, size: usize) -> usize;
|
||||
|
||||
/// Returns the size of the stratch space necessary for [MatZnxDftOps::vmp_apply_dft].
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `c_size`: number of size of the output [VecZnxDft].
|
||||
/// * `a_size`: number of size of the input [VecZnx].
|
||||
/// * `rows`: number of rows of the input [MatZnxDft].
|
||||
/// * `size`: number of size of the input [MatZnxDft].
|
||||
fn vmp_apply_dft_tmp_bytes(
|
||||
&self,
|
||||
c_size: usize,
|
||||
a_size: usize,
|
||||
b_rows: usize,
|
||||
b_cols_in: usize,
|
||||
b_cols_out: usize,
|
||||
b_size: usize,
|
||||
) -> usize;
|
||||
|
||||
/// Returns the size of the stratch space necessary for [MatZnxDftOps::vmp_apply_dft_to_dft].
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `c_size`: number of size of the output [VecZnxDft].
|
||||
/// * `a_size`: number of size of the input [VecZnxDft].
|
||||
/// * `rows`: number of rows of the input [MatZnxDft].
|
||||
/// * `size`: number of size of the input [MatZnxDft].
|
||||
fn vmp_apply_dft_to_dft_tmp_bytes(
|
||||
&self,
|
||||
c_cols: usize,
|
||||
c_size: usize,
|
||||
a_cols: usize,
|
||||
a_size: usize,
|
||||
b_rows: usize,
|
||||
b_cols_in: usize,
|
||||
b_cols_out: usize,
|
||||
b_size: usize,
|
||||
) -> usize;
|
||||
}
|
||||
|
||||
/// This trait implements methods for vector matrix product,
|
||||
/// that is, multiplying a [VecZnx] with a [MatZnxDft].
|
||||
pub trait MatZnxDftOps<DataMut, Data, B: Backend> {
|
||||
/// Prepares the ith-row of [MatZnxDft] from a [VecZnx].
|
||||
///
|
||||
/// # Arguments
|
||||
@@ -58,12 +101,9 @@ pub trait MatZnxDftOps<DataMut, Data, B: Backend> {
|
||||
b_row: usize,
|
||||
b_col_in: usize,
|
||||
a: &VecZnx<Data>,
|
||||
scratch: &mut ScratchSpace,
|
||||
scratch: &mut ScratchBorr,
|
||||
);
|
||||
|
||||
/// Returns the of bytes needed as scratch space for [MatZnxDftOps::vmp_extract_row]
|
||||
fn vmp_extract_row_tmp_bytes(&self, cols_out: usize, size: usize) -> usize;
|
||||
|
||||
/// Extracts the ith-row of [MatZnxDft] into a [VecZnxBig].
|
||||
///
|
||||
/// # Arguments
|
||||
@@ -78,7 +118,7 @@ pub trait MatZnxDftOps<DataMut, Data, B: Backend> {
|
||||
a: &MatZnxDft<Data, B>,
|
||||
b_row: usize,
|
||||
b_col_in: usize,
|
||||
scratch: &mut ScratchSpace,
|
||||
scratch: &mut ScratchBorr,
|
||||
);
|
||||
|
||||
/// Prepares the ith-row of [MatZnxDft] from a [VecZnxDft].
|
||||
@@ -101,24 +141,6 @@ pub trait MatZnxDftOps<DataMut, Data, B: Backend> {
|
||||
/// * `row_i`: the index of the row to extract.
|
||||
fn vmp_extract_row_dft(&self, b: &mut VecZnxDft<DataMut, B>, a: &MatZnxDft<Data, B>, a_row: usize, a_col_in: usize);
|
||||
|
||||
/// Returns the size of the stratch space necessary for [MatZnxDftOps::vmp_apply_dft].
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `c_size`: number of size of the output [VecZnxDft].
|
||||
/// * `a_size`: number of size of the input [VecZnx].
|
||||
/// * `rows`: number of rows of the input [MatZnxDft].
|
||||
/// * `size`: number of size of the input [MatZnxDft].
|
||||
fn vmp_apply_dft_tmp_bytes(
|
||||
&self,
|
||||
c_size: usize,
|
||||
a_size: usize,
|
||||
b_rows: usize,
|
||||
b_cols_in: usize,
|
||||
b_cols_out: usize,
|
||||
b_size: usize,
|
||||
) -> usize;
|
||||
|
||||
/// Applies the vector matrix product [VecZnxDft] x [MatZnxDft].
|
||||
///
|
||||
/// A vector matrix product is equivalent to a sum of [crate::SvpPPolOps::svp_apply_dft]
|
||||
@@ -143,27 +165,7 @@ pub trait MatZnxDftOps<DataMut, Data, B: Backend> {
|
||||
/// * `a`: the left operand [VecZnx] of the vector matrix product.
|
||||
/// * `b`: the right operand [MatZnxDft] of the vector matrix product.
|
||||
/// * `buf`: scratch space, the size can be obtained with [MatZnxDftOps::vmp_apply_dft_tmp_bytes].
|
||||
fn vmp_apply_dft(&self, c: &mut VecZnxDft<DataMut, B>, a: &VecZnx<Data>, b: &MatZnxDft<Data, B>, scratch: &mut ScratchSpace);
|
||||
|
||||
/// Returns the size of the stratch space necessary for [MatZnxDftOps::vmp_apply_dft_to_dft].
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `c_size`: number of size of the output [VecZnxDft].
|
||||
/// * `a_size`: number of size of the input [VecZnxDft].
|
||||
/// * `rows`: number of rows of the input [MatZnxDft].
|
||||
/// * `size`: number of size of the input [MatZnxDft].
|
||||
fn vmp_apply_dft_to_dft_tmp_bytes(
|
||||
&self,
|
||||
c_cols: usize,
|
||||
c_size: usize,
|
||||
a_cols: usize,
|
||||
a_size: usize,
|
||||
b_rows: usize,
|
||||
b_cols_in: usize,
|
||||
b_cols_out: usize,
|
||||
b_size: usize,
|
||||
) -> usize;
|
||||
fn vmp_apply_dft(&self, c: &mut VecZnxDft<DataMut, B>, a: &VecZnx<Data>, b: &MatZnxDft<Data, B>, scratch: &mut ScratchBorr);
|
||||
|
||||
/// Applies the vector matrix product [VecZnxDft] x [MatZnxDft].
|
||||
/// The size of `buf` is given by [MatZnxDftOps::vmp_apply_dft_to_dft_tmp_bytes].
|
||||
@@ -195,7 +197,7 @@ pub trait MatZnxDftOps<DataMut, Data, B: Backend> {
|
||||
c: &mut VecZnxDft<DataMut, B>,
|
||||
a: &VecZnxDft<Data, B>,
|
||||
b: &MatZnxDft<Data, B>,
|
||||
scratch: &mut ScratchSpace,
|
||||
scratch: &mut ScratchBorr,
|
||||
);
|
||||
}
|
||||
|
||||
@@ -220,22 +222,70 @@ impl<B: Backend> MatZnxDftAlloc<B> for Module<B> {
|
||||
}
|
||||
}
|
||||
|
||||
impl<DataMut, Data> MatZnxDftOps<DataMut, Data, FFT64> for Module<FFT64>
|
||||
where
|
||||
DataMut: AsMut<[u8]> + AsRef<[u8]>,
|
||||
Data: AsRef<[u8]>,
|
||||
{
|
||||
impl<B: Backend> MatZnxDftScratch for Module<B> {
|
||||
fn vmp_prepare_row_tmp_bytes(&self, cols_out: usize, size: usize) -> usize {
|
||||
<Self as VecZnxDftAlloc<FFT64>>::bytes_of_vec_znx_dft(self, cols_out, size)
|
||||
<Self as VecZnxDftAlloc<_>>::bytes_of_vec_znx_dft(self, cols_out, size)
|
||||
}
|
||||
|
||||
fn vmp_extract_row_tmp_bytes(&self, cols_out: usize, size: usize) -> usize {
|
||||
<Self as VecZnxDftAlloc<_>>::bytes_of_vec_znx_dft(self, cols_out, size)
|
||||
+ <Self as VecZnxBigScratch>::vec_znx_big_normalize_tmp_bytes(self)
|
||||
}
|
||||
|
||||
fn vmp_apply_dft_tmp_bytes(
|
||||
&self,
|
||||
c_size: usize,
|
||||
a_size: usize,
|
||||
b_rows: usize,
|
||||
b_cols_in: usize,
|
||||
b_cols_out: usize,
|
||||
b_size: usize,
|
||||
) -> usize {
|
||||
unsafe {
|
||||
vmp::vmp_apply_dft_tmp_bytes(
|
||||
self.ptr,
|
||||
c_size as u64,
|
||||
a_size as u64,
|
||||
(b_rows * b_cols_in) as u64,
|
||||
(b_size * b_cols_out) as u64,
|
||||
) as usize
|
||||
}
|
||||
}
|
||||
fn vmp_apply_dft_to_dft_tmp_bytes(
|
||||
&self,
|
||||
c_cols: usize,
|
||||
c_size: usize,
|
||||
a_cols: usize,
|
||||
a_size: usize,
|
||||
b_rows: usize,
|
||||
b_cols_in: usize,
|
||||
b_cols_out: usize,
|
||||
b_size: usize,
|
||||
) -> usize {
|
||||
unsafe {
|
||||
vmp::vmp_apply_dft_to_dft_tmp_bytes(
|
||||
self.ptr,
|
||||
(c_size * c_cols) as u64,
|
||||
(a_size * a_cols) as u64,
|
||||
(b_rows * b_cols_in) as u64,
|
||||
(b_size * b_cols_out) as u64,
|
||||
) as usize
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<DataMut, Data> MatZnxDftOps<DataMut, Data, FFT64> for Module<FFT64>
|
||||
where
|
||||
DataMut: AsMut<[u8]> + AsRef<[u8]> + for<'a> From<&'a mut [u8]>,
|
||||
Data: AsRef<[u8]>,
|
||||
{
|
||||
fn vmp_prepare_row(
|
||||
&self,
|
||||
b: &mut MatZnxDft<DataMut, FFT64>,
|
||||
b_row: usize,
|
||||
b_col_in: usize,
|
||||
a: &VecZnx<Data>,
|
||||
scratch: &mut ScratchSpace,
|
||||
scratch: &mut ScratchBorr,
|
||||
) {
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
@@ -278,17 +328,13 @@ where
|
||||
let a_size: usize = a.size();
|
||||
|
||||
// let (tmp_bytes_a_dft, _) = tmp_bytes.split_at_mut(self.bytes_of_vec_znx_dft(cols_out, a_size));
|
||||
let mut a_dft = scratch.tmp_vec_znx_dft::<DataMut, _>(self.n(), cols_out, a_size);
|
||||
let (mut a_dft, _) = scratch.tmp_scalar_slice(12);
|
||||
DataMut::from(a_dft);
|
||||
// let (mut a_dft, _) = scratch.tmp_vec_znx_dft::<DataMut, _>(self, cols_out, a_size);
|
||||
(0..cols_out).for_each(|i| self.vec_znx_dft(&mut a_dft, i, &a, i));
|
||||
|
||||
Self::vmp_prepare_row_dft(&self, b, b_row, b_col_in, &a_dft);
|
||||
}
|
||||
|
||||
fn vmp_extract_row_tmp_bytes(&self, cols_out: usize, size: usize) -> usize {
|
||||
self.bytes_of_vec_znx_dft(cols_out, size)
|
||||
+ <Self as VecZnxBigOps<DataMut, Data, FFT64>>::vec_znx_big_normalize_tmp_bytes(self)
|
||||
}
|
||||
|
||||
fn vmp_extract_row(
|
||||
&self,
|
||||
log_base2k: usize,
|
||||
@@ -296,7 +342,7 @@ where
|
||||
a: &MatZnxDft<Data, FFT64>,
|
||||
a_row: usize,
|
||||
a_col_in: usize,
|
||||
scratch: &mut ScratchSpace,
|
||||
mut scratch: &mut ScratchBorr,
|
||||
) {
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
@@ -336,9 +382,9 @@ where
|
||||
let size: usize = b.size();
|
||||
|
||||
// let (bytes_a_dft, tmp_bytes) = tmp_bytes.split_at_mut(self.bytes_of_vec_znx_dft(cols_out, size));
|
||||
let mut b_dft = scratch.tmp_vec_znx_dft::<DataMut, _>(self.n(), cols_out, size);
|
||||
let (mut b_dft, scratch) = scratch.tmp_vec_znx_dft(self, cols_out, size);
|
||||
Self::vmp_extract_row_dft(&self, &mut b_dft, a, a_row, a_col_in);
|
||||
let mut b_big = scratch.tmp_vec_znx_big(self.n(), cols_out, size);
|
||||
let (mut b_big, scratch) = scratch.tmp_vec_znx_big(self, cols_out, size);
|
||||
(0..cols_out).for_each(|i| {
|
||||
<Self as VecZnxDftOps<DataMut, Data, FFT64>>::vec_znx_idft_tmp_a(self, &mut b_big, i, &mut b_dft, i);
|
||||
self.vec_znx_big_normalize(log_base2k, b, i, &b_big, i, scratch);
|
||||
@@ -434,32 +480,12 @@ where
|
||||
}
|
||||
}
|
||||
|
||||
fn vmp_apply_dft_tmp_bytes(
|
||||
&self,
|
||||
res_size: usize,
|
||||
a_size: usize,
|
||||
b_rows: usize,
|
||||
b_cols_in: usize,
|
||||
b_cols_out: usize,
|
||||
b_size: usize,
|
||||
) -> usize {
|
||||
unsafe {
|
||||
vmp::vmp_apply_dft_tmp_bytes(
|
||||
self.ptr,
|
||||
res_size as u64,
|
||||
a_size as u64,
|
||||
(b_rows * b_cols_in) as u64,
|
||||
(b_size * b_cols_out) as u64,
|
||||
) as usize
|
||||
}
|
||||
}
|
||||
|
||||
fn vmp_apply_dft(
|
||||
&self,
|
||||
c: &mut VecZnxDft<DataMut, FFT64>,
|
||||
a: &VecZnx<Data>,
|
||||
b: &MatZnxDft<Data, FFT64>,
|
||||
scratch: &mut ScratchSpace,
|
||||
mut scratch: &mut ScratchBorr,
|
||||
) {
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
@@ -493,6 +519,16 @@ where
|
||||
// );
|
||||
// assert_alignement(tmp_bytes.as_ptr());
|
||||
}
|
||||
let (tmp_bytes, _) = scratch.tmp_scalar_slice(<Self as MatZnxDftScratch>::vmp_apply_dft_tmp_bytes(
|
||||
self,
|
||||
c.size(),
|
||||
a.size(),
|
||||
b.rows(),
|
||||
b.cols_in(),
|
||||
b.cols_out(),
|
||||
b.size(),
|
||||
));
|
||||
|
||||
unsafe {
|
||||
vmp::vmp_apply_dft(
|
||||
self.ptr,
|
||||
@@ -504,39 +540,17 @@ where
|
||||
b.as_ptr() as *const vmp::vmp_pmat_t,
|
||||
(b.rows() * b.cols_in()) as u64,
|
||||
(b.size() * b.cols_out()) as u64,
|
||||
scratch.vmp_apply_dft_tmp_bytes(self).as_mut_ptr(),
|
||||
tmp_bytes.as_mut_ptr(),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
fn vmp_apply_dft_to_dft_tmp_bytes(
|
||||
&self,
|
||||
res_cols: usize,
|
||||
res_size: usize,
|
||||
a_size: usize,
|
||||
a_cols: usize,
|
||||
b_rows: usize,
|
||||
b_cols_in: usize,
|
||||
b_cols_out: usize,
|
||||
b_size: usize,
|
||||
) -> usize {
|
||||
unsafe {
|
||||
vmp::vmp_apply_dft_to_dft_tmp_bytes(
|
||||
self.ptr,
|
||||
(res_size * res_cols) as u64,
|
||||
(a_size * a_cols) as u64,
|
||||
(b_rows * b_cols_in) as u64,
|
||||
(b_size * b_cols_out) as u64,
|
||||
) as usize
|
||||
}
|
||||
}
|
||||
|
||||
fn vmp_apply_dft_to_dft(
|
||||
&self,
|
||||
c: &mut VecZnxDft<DataMut, FFT64>,
|
||||
a: &VecZnxDft<Data, FFT64>,
|
||||
b: &MatZnxDft<Data, FFT64>,
|
||||
scratch: &mut ScratchSpace,
|
||||
mut scratch: &mut ScratchBorr,
|
||||
) {
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
@@ -572,6 +586,17 @@ where
|
||||
// );
|
||||
// assert_alignement(tmp_bytes.as_ptr());
|
||||
}
|
||||
|
||||
let (tmp_bytes, _) = scratch.tmp_scalar_slice(self.vmp_apply_dft_to_dft_tmp_bytes(
|
||||
c.cols(),
|
||||
c.size(),
|
||||
a.cols(),
|
||||
a.size(),
|
||||
b.rows(),
|
||||
b.cols_in(),
|
||||
b.cols_out(),
|
||||
b.size(),
|
||||
));
|
||||
unsafe {
|
||||
vmp::vmp_apply_dft_to_dft(
|
||||
self.ptr,
|
||||
@@ -582,7 +607,7 @@ where
|
||||
b.as_ptr() as *const vmp::vmp_pmat_t,
|
||||
b.rows() as u64,
|
||||
(b.size() * b.cols()) as u64,
|
||||
scratch.vmp_apply_dft_to_dft_tmp_bytes(self).as_mut_ptr(),
|
||||
tmp_bytes.as_mut_ptr(),
|
||||
)
|
||||
}
|
||||
}
|
||||
@@ -590,6 +615,7 @@ where
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use crate::ScratchOwned;
|
||||
use crate::mat_znx_dft_ops::*;
|
||||
use crate::vec_znx_big_ops::*;
|
||||
use crate::vec_znx_dft_ops::*;
|
||||
@@ -617,7 +643,9 @@ mod tests {
|
||||
|
||||
// let mut tmp_bytes: Vec<u8> =
|
||||
// alloc_aligned(module.vmp_prepare_row_tmp_bytes(mat_cols_out, mat_size) | module.vec_znx_big_normalize_tmp_bytes());
|
||||
let mut scratch = ScratchSpace {};
|
||||
let mut scratch = ScratchOwned::new(
|
||||
2 * (module.vmp_prepare_row_tmp_bytes(mat_cols_out, mat_size) + module.vec_znx_big_normalize_tmp_bytes()),
|
||||
);
|
||||
let mut tmp_bytes: Vec<u8> =
|
||||
alloc_aligned::<u8>(<Module<FFT64> as VecZnxDftOps<Vec<u8>, Vec<u8>, _>>::vec_znx_idft_tmp_bytes(&module));
|
||||
|
||||
@@ -630,7 +658,9 @@ mod tests {
|
||||
module.vec_znx_dft(&mut a_dft, col_out, &a, col_out);
|
||||
});
|
||||
|
||||
module.vmp_prepare_row(&mut vmpmat_0, row_i, col_in, &a, &mut scratch);
|
||||
// let g = vmpmat_0.to_mut();
|
||||
|
||||
module.vmp_prepare_row(&mut vmpmat_0.to_mut(), row_i, col_in, &a, scratch.borrow());
|
||||
|
||||
// Checks that prepare(mat_znx_dft, a) = prepare_dft(mat_znx_dft, a_dft)
|
||||
module.vmp_prepare_row_dft(&mut vmpmat_1, row_i, col_in, &a_dft);
|
||||
@@ -641,11 +671,25 @@ mod tests {
|
||||
assert_eq!(a_dft.raw(), b_dft.raw());
|
||||
|
||||
// Checks that a_big = extract(prepare_dft(mat_znx_dft, a_dft), b_big)
|
||||
module.vmp_extract_row(log_base2k, &mut b, &vmpmat_0, row_i, col_in, &mut scratch);
|
||||
// module.vmp_extract_row(
|
||||
// log_base2k,
|
||||
// &mut b.to_mut(),
|
||||
// &vmpmat_0.to_ref(),
|
||||
// row_i,
|
||||
// col_in,
|
||||
// scratch.borrow(),
|
||||
// );
|
||||
|
||||
(0..mat_cols_out).for_each(|col_out| {
|
||||
module.vec_znx_idft(&mut a_big, col_out, &a_dft, col_out, &mut tmp_bytes);
|
||||
module.vec_znx_big_normalize(log_base2k, &mut a, col_out, &a_big, col_out, &mut scratch);
|
||||
module.vec_znx_big_normalize(
|
||||
log_base2k,
|
||||
&mut a.to_mut(),
|
||||
col_out,
|
||||
&a_big.to_ref(),
|
||||
col_out,
|
||||
scratch.borrow(),
|
||||
);
|
||||
});
|
||||
|
||||
assert_eq!(a.raw(), b.raw());
|
||||
|
||||
Reference in New Issue
Block a user