reworked scalar

This commit is contained in:
Jean-Philippe Bossuat
2025-04-30 23:11:43 +02:00
parent 6f7b93c7ca
commit 9ade995cd7
8 changed files with 311 additions and 338 deletions

View File

@@ -160,7 +160,7 @@ pub trait MatZnxDftOps<B: Backend> {
/// * `b`: the [VecZnxDft] to on which to extract the row of the [MatZnxDft].
/// * `a`: [MatZnxDft] on which the values are encoded.
/// * `row_i`: the index of the row to extract.
fn vmp_extract_row_dft(&self, b: &mut VecZnxDft<B>, a: &MatZnxDft<B>, row_i: usize);
fn vmp_extract_row_dft(&self, b: &mut VecZnxDft<B>, row_i: usize, a: &MatZnxDft<B>);
/// Returns the size of the stratch space necessary for [MatZnxDftOps::vmp_apply_dft].
///
@@ -170,7 +170,7 @@ pub trait MatZnxDftOps<B: Backend> {
/// * `a_size`: number of size of the input [VecZnx].
/// * `rows`: number of rows of the input [MatZnxDft].
/// * `size`: number of size of the input [MatZnxDft].
fn vmp_apply_dft_tmp_bytes(&self, c_size: usize, a_size: usize, rows: usize, size: usize) -> usize;
fn vmp_apply_dft_tmp_bytes(&self, c_size: usize, a_size: usize, b_rows: usize, b_size: usize) -> usize;
/// Applies the vector matrix product [VecZnxDft] x [MatZnxDft].
///
@@ -404,7 +404,7 @@ impl MatZnxDftOps<FFT64> for Module<FFT64> {
}
}
fn vmp_extract_row_dft(&self, b: &mut VecZnxDft<FFT64>, a: &MatZnxDft<FFT64>, row_i: usize) {
fn vmp_extract_row_dft(&self, b: &mut VecZnxDft<FFT64>, row_i: usize, a: &MatZnxDft<FFT64>) {
#[cfg(debug_assertions)]
{
assert_eq!(a.n(), b.n());
@@ -422,14 +422,14 @@ impl MatZnxDftOps<FFT64> for Module<FFT64> {
}
}
fn vmp_apply_dft_tmp_bytes(&self, res_size: usize, a_size: usize, gct_rows: usize, gct_size: usize) -> usize {
fn vmp_apply_dft_tmp_bytes(&self, res_size: usize, a_size: usize, b_rows: usize, b_size: usize) -> usize {
unsafe {
vmp::vmp_apply_dft_tmp_bytes(
self.ptr,
res_size as u64,
a_size as u64,
gct_rows as u64,
gct_size as u64,
b_rows as u64,
b_size as u64,
) as usize
}
}
@@ -595,7 +595,7 @@ mod tests {
assert_eq!(vmpmat_0.raw(), vmpmat_1.raw());
// Checks that a_dft = extract_dft(prepare(mat_znx_dft, a), b_dft)
module.vmp_extract_row_dft(&mut b_dft, &vmpmat_0, row_i);
module.vmp_extract_row_dft(&mut b_dft, row_i, &vmpmat_0);
assert_eq!(a_dft.raw(), b_dft.raw());
// Checks that a_big = extract(prepare_dft(mat_znx_dft, a_dft), b_big)