added Added vmp_extract_row, vmp_extract_row_dft, vmp_extract_tmp_bytes, vmp_prepare_row_dft

-
This commit is contained in:
Jean-Philippe Bossuat
2025-04-16 11:31:58 +02:00
parent 4c1dbc70e5
commit 89369dcdf9
18 changed files with 293 additions and 181 deletions

View File

@@ -1,15 +1,15 @@
use crate::ffi::vec_znx_big::vec_znx_bigcoeff_t;
use crate::ffi::vec_znx_big::vec_znx_big_t;
use crate::ffi::vec_znx_dft;
use crate::ffi::vec_znx_dft::{bytes_of_vec_znx_dft, vec_znx_dft_t};
use crate::{alloc_aligned, VecZnx};
use crate::{assert_alignement, Infos, Module, VecZnxBig, MODULETYPE};
use crate::{assert_alignement, Infos, Module, VecZnxBig, BACKEND};
pub struct VecZnxDft {
pub data: Vec<u8>,
pub ptr: *mut u8,
pub n: usize,
pub cols: usize,
pub backend: MODULETYPE,
pub backend: BACKEND,
}
impl VecZnxDft {
@@ -69,7 +69,7 @@ impl VecZnxDft {
self.cols
}
pub fn backend(&self) -> MODULETYPE {
pub fn backend(&self) -> BACKEND {
self.backend
}
@@ -133,17 +133,17 @@ pub trait VecZnxDftOps {
fn vec_znx_idft_tmp_bytes(&self) -> usize;
/// b <- IDFT(a), uses a as scratch space.
fn vec_znx_idft_tmp_a(&self, b: &mut VecZnxBig, a: &mut VecZnxDft, a_limbs: usize);
fn vec_znx_idft_tmp_a(&self, b: &mut VecZnxBig, a: &mut VecZnxDft, a_cols: usize);
fn vec_znx_idft(
&self,
b: &mut VecZnxBig,
a: &mut VecZnxDft,
a_limbs: usize,
a: &VecZnxDft,
a_cols: usize,
tmp_bytes: &mut [u8],
);
fn vec_znx_dft(&self, b: &mut VecZnxDft, a: &VecZnx, a_limbs: usize);
fn vec_znx_dft(&self, b: &mut VecZnxDft, a: &VecZnx, a_cols: usize);
}
impl VecZnxDftOps for Module {
@@ -177,20 +177,20 @@ impl VecZnxDftOps for Module {
unsafe { bytes_of_vec_znx_dft(self.ptr, cols as u64) as usize }
}
fn vec_znx_idft_tmp_a(&self, b: &mut VecZnxBig, a: &mut VecZnxDft, a_limbs: usize) {
fn vec_znx_idft_tmp_a(&self, b: &mut VecZnxBig, a: &mut VecZnxDft, a_cols: usize) {
debug_assert!(
b.cols() >= a_limbs,
"invalid c_vector: b_vector.cols()={} < a_limbs={}",
b.cols() >= a_cols,
"invalid c_vector: b_vector.cols()={} < a_cols={}",
b.cols(),
a_limbs
a_cols
);
unsafe {
vec_znx_dft::vec_znx_idft_tmp_a(
self.ptr,
b.ptr as *mut vec_znx_bigcoeff_t,
b.ptr as *mut vec_znx_big_t,
b.cols() as u64,
a.ptr as *mut vec_znx_dft_t,
a_limbs as u64,
a_cols as u64,
)
}
}
@@ -226,7 +226,7 @@ impl VecZnxDftOps for Module {
fn vec_znx_idft(
&self,
b: &mut VecZnxBig,
a: &mut VecZnxDft,
a: &VecZnxDft,
a_cols: usize,
tmp_bytes: &mut [u8],
) {
@@ -243,7 +243,7 @@ impl VecZnxDftOps for Module {
a_cols
);
debug_assert!(
tmp_bytes.len() <= <Module as VecZnxDftOps>::vec_znx_idft_tmp_bytes(self),
tmp_bytes.len() >= <Module as VecZnxDftOps>::vec_znx_idft_tmp_bytes(self),
"invalid tmp_bytes: tmp_bytes.len()={} < self.vec_znx_idft_tmp_bytes()={}",
tmp_bytes.len(),
<Module as VecZnxDftOps>::vec_znx_idft_tmp_bytes(self)
@@ -255,9 +255,9 @@ impl VecZnxDftOps for Module {
unsafe {
vec_znx_dft::vec_znx_idft(
self.ptr,
b.ptr as *mut vec_znx_bigcoeff_t,
b.ptr as *mut vec_znx_big_t,
a.cols() as u64,
a.ptr as *mut vec_znx_dft_t,
a.ptr as *const vec_znx_dft_t,
a_cols as u64,
tmp_bytes.as_mut_ptr(),
)