added vec_znx_idft_consume

This commit is contained in:
Jean-Philippe Bossuat
2025-05-08 10:16:20 +02:00
parent 48ac28c4ce
commit 2ec905bbc3
3 changed files with 36 additions and 3 deletions

View File

@@ -1,6 +1,6 @@
use crate::ffi::vec_znx_big; use crate::ffi::vec_znx_big;
use crate::znx_base::{ZnxInfos, ZnxView}; use crate::znx_base::{ZnxInfos, ZnxView};
use crate::{Backend, DataView, DataViewMut, FFT64, Module, ZnxSliceSize, alloc_aligned}; use crate::{Backend, DataView, DataViewMut, FFT64, Module, VecZnxDft, ZnxSliceSize, alloc_aligned};
use std::fmt; use std::fmt;
use std::marker::PhantomData; use std::marker::PhantomData;

View File

@@ -2,7 +2,7 @@ use std::marker::PhantomData;
use crate::ffi::vec_znx_dft; use crate::ffi::vec_znx_dft;
use crate::znx_base::ZnxInfos; use crate::znx_base::ZnxInfos;
use crate::{Backend, DataView, DataViewMut, FFT64, Module, ZnxSliceSize, ZnxView, alloc_aligned}; use crate::{Backend, DataView, DataViewMut, FFT64, Module, VecZnxBig, ZnxSliceSize, ZnxView, alloc_aligned};
use std::fmt; use std::fmt;
pub struct VecZnxDft<D, B: Backend> { pub struct VecZnxDft<D, B: Backend> {
@@ -13,6 +13,12 @@ pub struct VecZnxDft<D, B: Backend> {
_phantom: PhantomData<B>, _phantom: PhantomData<B>,
} }
impl<D, B: Backend> VecZnxDft<D, B> {
pub fn into_big(self) -> VecZnxBig<D, B> {
VecZnxBig::<D, B>::from_data(self.data, self.n, self.cols, self.size)
}
}
impl<D, B: Backend> ZnxInfos for VecZnxDft<D, B> { impl<D, B: Backend> ZnxInfos for VecZnxDft<D, B> {
fn cols(&self) -> usize { fn cols(&self) -> usize {
self.cols self.cols

View File

@@ -1,7 +1,10 @@
use crate::ffi::{vec_znx_big, vec_znx_dft}; use crate::ffi::{vec_znx_big, vec_znx_dft};
use crate::vec_znx_dft::bytes_of_vec_znx_dft; use crate::vec_znx_dft::bytes_of_vec_znx_dft;
use crate::znx_base::ZnxInfos; use crate::znx_base::ZnxInfos;
use crate::{Backend, Scratch, VecZnxBigToMut, VecZnxDftOwned, VecZnxDftToMut, VecZnxDftToRef, VecZnxToRef, ZnxSliceSize}; use crate::{
Backend, Scratch, VecZnxBig, VecZnxBigToMut, VecZnxDft, VecZnxDftOwned, VecZnxDftToMut, VecZnxDftToRef, VecZnxToRef,
ZnxSliceSize,
};
use crate::{FFT64, Module, ZnxView, ZnxViewMut, ZnxZero}; use crate::{FFT64, Module, ZnxView, ZnxViewMut, ZnxZero};
use std::cmp::min; use std::cmp::min;
@@ -44,6 +47,9 @@ pub trait VecZnxDftOps<B: Backend> {
where where
R: VecZnxBigToMut<B>, R: VecZnxBigToMut<B>,
A: VecZnxDftToMut<B>; A: VecZnxDftToMut<B>;
fn vec_znx_idft_consume<D>(&self, a: VecZnxDft<D, B>, a_cols: usize) -> VecZnxBig<D, FFT64>
where
VecZnxDft<D, FFT64>: VecZnxDftToMut<FFT64>;
fn vec_znx_idft<R, A>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize, scratch: &mut Scratch) fn vec_znx_idft<R, A>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize, scratch: &mut Scratch)
where where
@@ -97,6 +103,27 @@ impl VecZnxDftOps<FFT64> for Module<FFT64> {
} }
} }
fn vec_znx_idft_consume<D>(&self, mut a: VecZnxDft<D, FFT64>, a_col: usize) -> VecZnxBig<D, FFT64>
where
VecZnxDft<D, FFT64>: VecZnxDftToMut<FFT64>,
{
let mut a_mut: VecZnxDft<&mut [u8], FFT64> = a.to_mut();
unsafe {
(0..a_mut.size()).for_each(|j| {
vec_znx_dft::vec_znx_idft_tmp_a(
self.ptr,
a_mut.at_mut_ptr(a_col, j) as *mut vec_znx_big::vec_znx_big_t,
1 as u64,
a_mut.at_mut_ptr(a_col, j) as *mut vec_znx_dft::vec_znx_dft_t,
1 as u64,
)
});
a.into_big()
}
}
fn vec_znx_idft_tmp_bytes(&self) -> usize { fn vec_znx_idft_tmp_bytes(&self) -> usize {
unsafe { vec_znx_dft::vec_znx_idft_tmp_bytes(self.ptr) as usize } unsafe { vec_znx_dft::vec_znx_idft_tmp_bytes(self.ptr) as usize }
} }