This commit is contained in:
Janmajaya Mall
2025-05-04 18:39:28 +05:30
parent ff8370e023
commit b82a1ca1b4
10 changed files with 551 additions and 446 deletions

View File

@@ -2,8 +2,8 @@ use crate::ffi::vec_znx_dft::vec_znx_dft_t;
use crate::ffi::vmp;
use crate::znx_base::{ZnxInfos, ZnxView, ZnxViewMut};
use crate::{
Backend, FFT64, MatZnxDft, MatZnxDftAllocOwned, Module, ScratchSpace, VecZnx, VecZnxBig, VecZnxBigOps, VecZnxDft,
VecZnxDftAlloc, VecZnxDftOps, assert_alignement, is_aligned,
Backend, FFT64, MatZnxDft, MatZnxDftAllocOwned, Module, ScratchBorr, VecZnx, VecZnxBigOps, VecZnxBigScratch, VecZnxDft,
VecZnxDftAlloc, VecZnxDftOps,
};
pub trait MatZnxDftAlloc<B> {
@@ -36,12 +36,55 @@ pub trait MatZnxDftAlloc<B> {
// ) -> MatZnxDft<FFT64>;
}
/// This trait implements methods for vector matrix product,
/// that is, multiplying a [VecZnx] with a [MatZnxDft].
pub trait MatZnxDftOps<DataMut, Data, B: Backend> {
/// Scratch-space sizing queries for the operations declared in [MatZnxDftOps].
///
/// Each method only takes dimensions (no buffers) and returns a `usize` scratch size.
pub trait MatZnxDftScratch {
/// Returns the number of bytes needed as scratch space for [MatZnxDftOps::vmp_prepare_row].
fn vmp_prepare_row_tmp_bytes(&self, cols_out: usize, size: usize) -> usize;
/// Returns the number of bytes needed as scratch space for [MatZnxDftOps::vmp_extract_row].
fn vmp_extract_row_tmp_bytes(&self, cols_out: usize, size: usize) -> usize;
/// Returns the size of the scratch space necessary for [MatZnxDftOps::vmp_apply_dft].
///
/// # Arguments
///
/// * `c_size`: `size()` of the output [VecZnxDft].
/// * `a_size`: `size()` of the input [VecZnx].
/// * `b_rows`: `rows()` of the input [MatZnxDft].
/// * `b_cols_in`: `cols_in()` of the input [MatZnxDft].
/// * `b_cols_out`: `cols_out()` of the input [MatZnxDft].
/// * `b_size`: `size()` of the input [MatZnxDft].
fn vmp_apply_dft_tmp_bytes(
&self,
c_size: usize,
a_size: usize,
b_rows: usize,
b_cols_in: usize,
b_cols_out: usize,
b_size: usize,
) -> usize;
/// Returns the size of the scratch space necessary for [MatZnxDftOps::vmp_apply_dft_to_dft].
///
/// # Arguments
///
/// * `c_cols`: `cols()` of the output [VecZnxDft].
/// * `c_size`: `size()` of the output [VecZnxDft].
/// * `a_cols`: `cols()` of the input [VecZnxDft].
/// * `a_size`: `size()` of the input [VecZnxDft].
/// * `b_rows`: `rows()` of the input [MatZnxDft].
/// * `b_cols_in`: `cols_in()` of the input [MatZnxDft].
/// * `b_cols_out`: `cols_out()` of the input [MatZnxDft].
/// * `b_size`: `size()` of the input [MatZnxDft].
fn vmp_apply_dft_to_dft_tmp_bytes(
&self,
c_cols: usize,
c_size: usize,
a_cols: usize,
a_size: usize,
b_rows: usize,
b_cols_in: usize,
b_cols_out: usize,
b_size: usize,
) -> usize;
}
/// This trait implements methods for vector matrix product,
/// that is, multiplying a [VecZnx] with a [MatZnxDft].
pub trait MatZnxDftOps<DataMut, Data, B: Backend> {
/// Prepares the ith-row of [MatZnxDft] from a [VecZnx].
///
/// # Arguments
@@ -58,12 +101,9 @@ pub trait MatZnxDftOps<DataMut, Data, B: Backend> {
b_row: usize,
b_col_in: usize,
a: &VecZnx<Data>,
scratch: &mut ScratchSpace,
scratch: &mut ScratchBorr,
);
/// Returns the number of bytes needed as scratch space for [MatZnxDftOps::vmp_extract_row].
fn vmp_extract_row_tmp_bytes(&self, cols_out: usize, size: usize) -> usize;
/// Extracts the ith-row of [MatZnxDft] into a [VecZnxBig].
///
/// # Arguments
@@ -78,7 +118,7 @@ pub trait MatZnxDftOps<DataMut, Data, B: Backend> {
a: &MatZnxDft<Data, B>,
b_row: usize,
b_col_in: usize,
scratch: &mut ScratchSpace,
scratch: &mut ScratchBorr,
);
/// Prepares the ith-row of [MatZnxDft] from a [VecZnxDft].
@@ -101,24 +141,6 @@ pub trait MatZnxDftOps<DataMut, Data, B: Backend> {
/// * `row_i`: the index of the row to extract.
fn vmp_extract_row_dft(&self, b: &mut VecZnxDft<DataMut, B>, a: &MatZnxDft<Data, B>, a_row: usize, a_col_in: usize);
/// Returns the size of the scratch space necessary for [MatZnxDftOps::vmp_apply_dft].
///
/// # Arguments
///
/// * `c_size`: number of size of the output [VecZnxDft].
/// * `a_size`: number of size of the input [VecZnx].
/// * `rows`: number of rows of the input [MatZnxDft].
/// * `size`: number of size of the input [MatZnxDft].
fn vmp_apply_dft_tmp_bytes(
&self,
c_size: usize,
a_size: usize,
b_rows: usize,
b_cols_in: usize,
b_cols_out: usize,
b_size: usize,
) -> usize;
/// Applies the vector matrix product [VecZnxDft] x [MatZnxDft].
///
/// A vector matrix product is equivalent to a sum of [crate::SvpPPolOps::svp_apply_dft]
@@ -143,27 +165,7 @@ pub trait MatZnxDftOps<DataMut, Data, B: Backend> {
/// * `a`: the left operand [VecZnx] of the vector matrix product.
/// * `b`: the right operand [MatZnxDft] of the vector matrix product.
/// * `scratch`: scratch space, the size can be obtained with [MatZnxDftScratch::vmp_apply_dft_tmp_bytes].
fn vmp_apply_dft(&self, c: &mut VecZnxDft<DataMut, B>, a: &VecZnx<Data>, b: &MatZnxDft<Data, B>, scratch: &mut ScratchSpace);
/// Returns the size of the scratch space necessary for [MatZnxDftOps::vmp_apply_dft_to_dft].
///
/// # Arguments
///
/// * `c_size`: number of size of the output [VecZnxDft].
/// * `a_size`: number of size of the input [VecZnxDft].
/// * `rows`: number of rows of the input [MatZnxDft].
/// * `size`: number of size of the input [MatZnxDft].
fn vmp_apply_dft_to_dft_tmp_bytes(
&self,
c_cols: usize,
c_size: usize,
a_cols: usize,
a_size: usize,
b_rows: usize,
b_cols_in: usize,
b_cols_out: usize,
b_size: usize,
) -> usize;
fn vmp_apply_dft(&self, c: &mut VecZnxDft<DataMut, B>, a: &VecZnx<Data>, b: &MatZnxDft<Data, B>, scratch: &mut ScratchBorr);
/// Applies the vector matrix product [VecZnxDft] x [MatZnxDft].
/// The size of `scratch` is given by [MatZnxDftScratch::vmp_apply_dft_to_dft_tmp_bytes].
@@ -195,7 +197,7 @@ pub trait MatZnxDftOps<DataMut, Data, B: Backend> {
c: &mut VecZnxDft<DataMut, B>,
a: &VecZnxDft<Data, B>,
b: &MatZnxDft<Data, B>,
scratch: &mut ScratchSpace,
scratch: &mut ScratchBorr,
);
}
@@ -220,22 +222,70 @@ impl<B: Backend> MatZnxDftAlloc<B> for Module<B> {
}
}
impl<DataMut, Data> MatZnxDftOps<DataMut, Data, FFT64> for Module<FFT64>
where
DataMut: AsMut<[u8]> + AsRef<[u8]>,
Data: AsRef<[u8]>,
{
// Scratch-size queries for `Module<B>`, implemented for every `Backend`.
//
// NOTE(review): this impl is generic over `B`, while the `MatZnxDftOps` impl in this file is written
// for `FFT64` only — confirm that the `vmp::*_tmp_bytes` FFI calls below are valid for every backend.
impl<B: Backend> MatZnxDftScratch for Module<B> {
/// Scratch size for `vmp_prepare_row`.
///
/// Delegates to [VecZnxDftAlloc::bytes_of_vec_znx_dft] with `cols_out` and `size`.
fn vmp_prepare_row_tmp_bytes(&self, cols_out: usize, size: usize) -> usize {
<Self as VecZnxDftAlloc<FFT64>>::bytes_of_vec_znx_dft(self, cols_out, size)
<Self as VecZnxDftAlloc<_>>::bytes_of_vec_znx_dft(self, cols_out, size)
}
/// Scratch size for `vmp_extract_row`.
///
/// Sum of [VecZnxDftAlloc::bytes_of_vec_znx_dft] (for `cols_out`, `size`) and
/// [VecZnxBigScratch::vec_znx_big_normalize_tmp_bytes].
// NOTE(review): `vmp_extract_row` in this file takes both a temporary `VecZnxDft` and a temporary
// `VecZnxBig` from the scratch before normalizing; no `VecZnxBig` term is added here — confirm
// that the returned size is sufficient.
fn vmp_extract_row_tmp_bytes(&self, cols_out: usize, size: usize) -> usize {
<Self as VecZnxDftAlloc<_>>::bytes_of_vec_znx_dft(self, cols_out, size)
+ <Self as VecZnxBigScratch>::vec_znx_big_normalize_tmp_bytes(self)
}
/// Scratch size for `vmp_apply_dft`.
///
/// Forwards to the FFI function `vmp::vmp_apply_dft_tmp_bytes`, passing `c_size` and `a_size`
/// unchanged and flattening the matrix dimensions to `b_rows * b_cols_in` and `b_size * b_cols_out`.
fn vmp_apply_dft_tmp_bytes(
&self,
c_size: usize,
a_size: usize,
b_rows: usize,
b_cols_in: usize,
b_cols_out: usize,
b_size: usize,
) -> usize {
// SAFETY(review): only `self.ptr` and integer sizes cross the FFI boundary; assumes `self.ptr`
// is a valid module handle — TODO confirm.
unsafe {
vmp::vmp_apply_dft_tmp_bytes(
self.ptr,
c_size as u64,
a_size as u64,
// The `usize` products are computed first, then cast to `u64` for the FFI call.
(b_rows * b_cols_in) as u64,
(b_size * b_cols_out) as u64,
) as usize
}
}
/// Scratch size for `vmp_apply_dft_to_dft`.
///
/// Forwards to the FFI function `vmp::vmp_apply_dft_to_dft_tmp_bytes` with the flattened dimensions
/// `c_size * c_cols`, `a_size * a_cols`, `b_rows * b_cols_in` and `b_size * b_cols_out`.
fn vmp_apply_dft_to_dft_tmp_bytes(
&self,
c_cols: usize,
c_size: usize,
a_cols: usize,
a_size: usize,
b_rows: usize,
b_cols_in: usize,
b_cols_out: usize,
b_size: usize,
) -> usize {
// SAFETY(review): only `self.ptr` and integer sizes cross the FFI boundary; assumes `self.ptr`
// is a valid module handle — TODO confirm.
unsafe {
vmp::vmp_apply_dft_to_dft_tmp_bytes(
self.ptr,
// Each operand's size is multiplied by its column count before the cast to `u64`.
(c_size * c_cols) as u64,
(a_size * a_cols) as u64,
(b_rows * b_cols_in) as u64,
(b_size * b_cols_out) as u64,
) as usize
}
}
}
impl<DataMut, Data> MatZnxDftOps<DataMut, Data, FFT64> for Module<FFT64>
where
DataMut: AsMut<[u8]> + AsRef<[u8]> + for<'a> From<&'a mut [u8]>,
Data: AsRef<[u8]>,
{
fn vmp_prepare_row(
&self,
b: &mut MatZnxDft<DataMut, FFT64>,
b_row: usize,
b_col_in: usize,
a: &VecZnx<Data>,
scratch: &mut ScratchSpace,
scratch: &mut ScratchBorr,
) {
#[cfg(debug_assertions)]
{
@@ -278,17 +328,13 @@ where
let a_size: usize = a.size();
// let (tmp_bytes_a_dft, _) = tmp_bytes.split_at_mut(self.bytes_of_vec_znx_dft(cols_out, a_size));
let mut a_dft = scratch.tmp_vec_znx_dft::<DataMut, _>(self.n(), cols_out, a_size);
let (mut a_dft, _) = scratch.tmp_scalar_slice(12);
DataMut::from(a_dft);
// let (mut a_dft, _) = scratch.tmp_vec_znx_dft::<DataMut, _>(self, cols_out, a_size);
(0..cols_out).for_each(|i| self.vec_znx_dft(&mut a_dft, i, &a, i));
Self::vmp_prepare_row_dft(&self, b, b_row, b_col_in, &a_dft);
}
fn vmp_extract_row_tmp_bytes(&self, cols_out: usize, size: usize) -> usize {
self.bytes_of_vec_znx_dft(cols_out, size)
+ <Self as VecZnxBigOps<DataMut, Data, FFT64>>::vec_znx_big_normalize_tmp_bytes(self)
}
fn vmp_extract_row(
&self,
log_base2k: usize,
@@ -296,7 +342,7 @@ where
a: &MatZnxDft<Data, FFT64>,
a_row: usize,
a_col_in: usize,
scratch: &mut ScratchSpace,
mut scratch: &mut ScratchBorr,
) {
#[cfg(debug_assertions)]
{
@@ -336,9 +382,9 @@ where
let size: usize = b.size();
// let (bytes_a_dft, tmp_bytes) = tmp_bytes.split_at_mut(self.bytes_of_vec_znx_dft(cols_out, size));
let mut b_dft = scratch.tmp_vec_znx_dft::<DataMut, _>(self.n(), cols_out, size);
let (mut b_dft, scratch) = scratch.tmp_vec_znx_dft(self, cols_out, size);
Self::vmp_extract_row_dft(&self, &mut b_dft, a, a_row, a_col_in);
let mut b_big = scratch.tmp_vec_znx_big(self.n(), cols_out, size);
let (mut b_big, scratch) = scratch.tmp_vec_znx_big(self, cols_out, size);
(0..cols_out).for_each(|i| {
<Self as VecZnxDftOps<DataMut, Data, FFT64>>::vec_znx_idft_tmp_a(self, &mut b_big, i, &mut b_dft, i);
self.vec_znx_big_normalize(log_base2k, b, i, &b_big, i, scratch);
@@ -434,32 +480,12 @@ where
}
}
fn vmp_apply_dft_tmp_bytes(
&self,
res_size: usize,
a_size: usize,
b_rows: usize,
b_cols_in: usize,
b_cols_out: usize,
b_size: usize,
) -> usize {
unsafe {
vmp::vmp_apply_dft_tmp_bytes(
self.ptr,
res_size as u64,
a_size as u64,
(b_rows * b_cols_in) as u64,
(b_size * b_cols_out) as u64,
) as usize
}
}
fn vmp_apply_dft(
&self,
c: &mut VecZnxDft<DataMut, FFT64>,
a: &VecZnx<Data>,
b: &MatZnxDft<Data, FFT64>,
scratch: &mut ScratchSpace,
mut scratch: &mut ScratchBorr,
) {
#[cfg(debug_assertions)]
{
@@ -493,6 +519,16 @@ where
// );
// assert_alignement(tmp_bytes.as_ptr());
}
let (tmp_bytes, _) = scratch.tmp_scalar_slice(<Self as MatZnxDftScratch>::vmp_apply_dft_tmp_bytes(
self,
c.size(),
a.size(),
b.rows(),
b.cols_in(),
b.cols_out(),
b.size(),
));
unsafe {
vmp::vmp_apply_dft(
self.ptr,
@@ -504,39 +540,17 @@ where
b.as_ptr() as *const vmp::vmp_pmat_t,
(b.rows() * b.cols_in()) as u64,
(b.size() * b.cols_out()) as u64,
scratch.vmp_apply_dft_tmp_bytes(self).as_mut_ptr(),
tmp_bytes.as_mut_ptr(),
)
}
}
fn vmp_apply_dft_to_dft_tmp_bytes(
&self,
res_cols: usize,
res_size: usize,
a_size: usize,
a_cols: usize,
b_rows: usize,
b_cols_in: usize,
b_cols_out: usize,
b_size: usize,
) -> usize {
unsafe {
vmp::vmp_apply_dft_to_dft_tmp_bytes(
self.ptr,
(res_size * res_cols) as u64,
(a_size * a_cols) as u64,
(b_rows * b_cols_in) as u64,
(b_size * b_cols_out) as u64,
) as usize
}
}
fn vmp_apply_dft_to_dft(
&self,
c: &mut VecZnxDft<DataMut, FFT64>,
a: &VecZnxDft<Data, FFT64>,
b: &MatZnxDft<Data, FFT64>,
scratch: &mut ScratchSpace,
mut scratch: &mut ScratchBorr,
) {
#[cfg(debug_assertions)]
{
@@ -572,6 +586,17 @@ where
// );
// assert_alignement(tmp_bytes.as_ptr());
}
let (tmp_bytes, _) = scratch.tmp_scalar_slice(self.vmp_apply_dft_to_dft_tmp_bytes(
c.cols(),
c.size(),
a.cols(),
a.size(),
b.rows(),
b.cols_in(),
b.cols_out(),
b.size(),
));
unsafe {
vmp::vmp_apply_dft_to_dft(
self.ptr,
@@ -582,7 +607,7 @@ where
b.as_ptr() as *const vmp::vmp_pmat_t,
b.rows() as u64,
(b.size() * b.cols()) as u64,
scratch.vmp_apply_dft_to_dft_tmp_bytes(self).as_mut_ptr(),
tmp_bytes.as_mut_ptr(),
)
}
}
@@ -590,6 +615,7 @@ where
#[cfg(test)]
mod tests {
use crate::ScratchOwned;
use crate::mat_znx_dft_ops::*;
use crate::vec_znx_big_ops::*;
use crate::vec_znx_dft_ops::*;
@@ -617,7 +643,9 @@ mod tests {
// let mut tmp_bytes: Vec<u8> =
// alloc_aligned(module.vmp_prepare_row_tmp_bytes(mat_cols_out, mat_size) | module.vec_znx_big_normalize_tmp_bytes());
let mut scratch = ScratchSpace {};
let mut scratch = ScratchOwned::new(
2 * (module.vmp_prepare_row_tmp_bytes(mat_cols_out, mat_size) + module.vec_znx_big_normalize_tmp_bytes()),
);
let mut tmp_bytes: Vec<u8> =
alloc_aligned::<u8>(<Module<FFT64> as VecZnxDftOps<Vec<u8>, Vec<u8>, _>>::vec_znx_idft_tmp_bytes(&module));
@@ -630,7 +658,9 @@ mod tests {
module.vec_znx_dft(&mut a_dft, col_out, &a, col_out);
});
module.vmp_prepare_row(&mut vmpmat_0, row_i, col_in, &a, &mut scratch);
// let g = vmpmat_0.to_mut();
module.vmp_prepare_row(&mut vmpmat_0.to_mut(), row_i, col_in, &a, scratch.borrow());
// Checks that prepare(mat_znx_dft, a) = prepare_dft(mat_znx_dft, a_dft)
module.vmp_prepare_row_dft(&mut vmpmat_1, row_i, col_in, &a_dft);
@@ -641,11 +671,25 @@ mod tests {
assert_eq!(a_dft.raw(), b_dft.raw());
// Checks that a_big = extract(prepare_dft(mat_znx_dft, a_dft), b_big)
module.vmp_extract_row(log_base2k, &mut b, &vmpmat_0, row_i, col_in, &mut scratch);
// module.vmp_extract_row(
// log_base2k,
// &mut b.to_mut(),
// &vmpmat_0.to_ref(),
// row_i,
// col_in,
// scratch.borrow(),
// );
(0..mat_cols_out).for_each(|col_out| {
module.vec_znx_idft(&mut a_big, col_out, &a_dft, col_out, &mut tmp_bytes);
module.vec_znx_big_normalize(log_base2k, &mut a, col_out, &a_big, col_out, &mut scratch);
module.vec_znx_big_normalize(
log_base2k,
&mut a.to_mut(),
col_out,
&a_big.to_ref(),
col_out,
scratch.borrow(),
);
});
assert_eq!(a.raw(), b.raw());