updated sampling traits

This commit is contained in:
Jean-Philippe Bossuat
2025-05-06 11:30:55 +02:00
parent ffa363804b
commit 08e81f50c9
22 changed files with 251 additions and 2778 deletions

View File

@@ -2,7 +2,7 @@ use crate::ffi::vec_znx_dft::vec_znx_dft_t;
use crate::ffi::vmp;
use crate::znx_base::{ZnxInfos, ZnxView, ZnxViewMut};
use crate::{
Backend, FFT64, MatZnxDft, MatZnxDftAllocOwned, MatZnxDftToMut, MatZnxDftToRef, Module, Scratch, VecZnxDft, VecZnxDftToMut,
Backend, FFT64, MatZnxDft, MatZnxDftOwned, MatZnxDftToMut, MatZnxDftToRef, Module, Scratch, VecZnxDft, VecZnxDftToMut,
VecZnxDftToRef,
};
@@ -13,7 +13,7 @@ pub trait MatZnxDftAlloc<B: Backend> {
///
/// * `rows`: number of rows (number of [VecZnxDft]).
/// * `size`: number of size (number of size of each [VecZnxDft]).
fn new_mat_znx_dft(&self, rows: usize, cols_in: usize, cols_out: usize, size: usize) -> MatZnxDftAllocOwned<B>;
fn new_mat_znx_dft(&self, rows: usize, cols_in: usize, cols_out: usize, size: usize) -> MatZnxDftOwned<B>;
fn bytes_of_mat_znx_dft(&self, rows: usize, cols_in: usize, cols_out: usize, size: usize) -> usize;
@@ -24,7 +24,7 @@ pub trait MatZnxDftAlloc<B: Backend> {
cols_out: usize,
size: usize,
bytes: Vec<u8>,
) -> MatZnxDftAllocOwned<B>;
) -> MatZnxDftOwned<B>;
}
pub trait MatZnxDftScratch {
@@ -103,11 +103,11 @@ pub trait MatZnxDftOps<BACKEND: Backend> {
impl<B: Backend> MatZnxDftAlloc<B> for Module<B> {
fn bytes_of_mat_znx_dft(&self, rows: usize, cols_in: usize, cols_out: usize, size: usize) -> usize {
MatZnxDftAllocOwned::bytes_of(self, rows, cols_in, cols_out, size)
MatZnxDftOwned::bytes_of(self, rows, cols_in, cols_out, size)
}
fn new_mat_znx_dft(&self, rows: usize, cols_in: usize, cols_out: usize, size: usize) -> MatZnxDftAllocOwned<B> {
MatZnxDftAllocOwned::new(self, rows, cols_in, cols_out, size)
fn new_mat_znx_dft(&self, rows: usize, cols_in: usize, cols_out: usize, size: usize) -> MatZnxDftOwned<B> {
MatZnxDftOwned::new(self, rows, cols_in, cols_out, size)
}
fn new_mat_znx_dft_from_bytes(
@@ -117,8 +117,8 @@ impl<B: Backend> MatZnxDftAlloc<B> for Module<B> {
cols_out: usize,
size: usize,
bytes: Vec<u8>,
) -> MatZnxDftAllocOwned<B> {
MatZnxDftAllocOwned::new_from_bytes(self, rows, cols_in, cols_out, size, bytes)
) -> MatZnxDftOwned<B> {
MatZnxDftOwned::new_from_bytes(self, rows, cols_in, cols_out, size, bytes)
}
}
@@ -305,8 +305,8 @@ impl MatZnxDftOps<FFT64> for Module<FFT64> {
#[cfg(test)]
mod tests {
use crate::{
Encoding, FFT64, MatZnxDft, MatZnxDftOps, Module, Sampling, ScratchOwned, VecZnx, VecZnxAlloc, VecZnxBig, VecZnxBigAlloc,
VecZnxBigOps, VecZnxBigScratch, VecZnxDft, VecZnxDftAlloc, VecZnxDftOps, ZnxInfos, ZnxView, ZnxViewMut,
Encoding, FFT64, FillUniform, MatZnxDft, MatZnxDftOps, Module, ScratchOwned, VecZnx, VecZnxAlloc, VecZnxBig,
VecZnxBigAlloc, VecZnxBigOps, VecZnxBigScratch, VecZnxDft, VecZnxDftAlloc, VecZnxDftOps, ZnxInfos, ZnxView, ZnxViewMut,
};
use sampling::source::Source;
@@ -329,7 +329,7 @@ mod tests {
for row_i in 0..mat_rows {
let mut source: Source = Source::new([0u8; 32]);
(0..mat_cols_out).for_each(|col_out| {
module.fill_uniform(log_base2k, &mut a, col_out, mat_size, &mut source);
a.fill_uniform(log_base2k, col_out, mat_size, &mut source);
module.vec_znx_dft(&mut a_dft, col_out, &a, col_out);
});
module.vmp_prepare_row(&mut mat, row_i, col_in, &a_dft);