From 2ec905bbc36cfa88f78b7d02044e9150e30c5ca4 Mon Sep 17 00:00:00 2001 From: Jean-Philippe Bossuat Date: Thu, 8 May 2025 10:16:20 +0200 Subject: [PATCH] added vec_znx_idft_consume --- base2k/src/vec_znx_big.rs | 2 +- base2k/src/vec_znx_dft.rs | 8 +++++++- base2k/src/vec_znx_dft_ops.rs | 29 ++++++++++++++++++++++++++++- 3 files changed, 36 insertions(+), 3 deletions(-) diff --git a/base2k/src/vec_znx_big.rs b/base2k/src/vec_znx_big.rs index d8c1bdd..2875b97 100644 --- a/base2k/src/vec_znx_big.rs +++ b/base2k/src/vec_znx_big.rs @@ -1,6 +1,6 @@ use crate::ffi::vec_znx_big; use crate::znx_base::{ZnxInfos, ZnxView}; -use crate::{Backend, DataView, DataViewMut, FFT64, Module, ZnxSliceSize, alloc_aligned}; +use crate::{Backend, DataView, DataViewMut, FFT64, Module, VecZnxDft, ZnxSliceSize, alloc_aligned}; use std::fmt; use std::marker::PhantomData; diff --git a/base2k/src/vec_znx_dft.rs b/base2k/src/vec_znx_dft.rs index 0e7f952..61e1be5 100644 --- a/base2k/src/vec_znx_dft.rs +++ b/base2k/src/vec_znx_dft.rs @@ -2,7 +2,7 @@ use std::marker::PhantomData; use crate::ffi::vec_znx_dft; use crate::znx_base::ZnxInfos; -use crate::{Backend, DataView, DataViewMut, FFT64, Module, ZnxSliceSize, ZnxView, alloc_aligned}; +use crate::{Backend, DataView, DataViewMut, FFT64, Module, VecZnxBig, ZnxSliceSize, ZnxView, alloc_aligned}; use std::fmt; pub struct VecZnxDft { @@ -13,6 +13,12 @@ pub struct VecZnxDft { _phantom: PhantomData, } +impl VecZnxDft { + pub fn into_big(self) -> VecZnxBig { + VecZnxBig::::from_data(self.data, self.n, self.cols, self.size) + } +} + impl ZnxInfos for VecZnxDft { fn cols(&self) -> usize { self.cols diff --git a/base2k/src/vec_znx_dft_ops.rs b/base2k/src/vec_znx_dft_ops.rs index 927e39e..cf06cc2 100644 --- a/base2k/src/vec_znx_dft_ops.rs +++ b/base2k/src/vec_znx_dft_ops.rs @@ -1,7 +1,10 @@ use crate::ffi::{vec_znx_big, vec_znx_dft}; use crate::vec_znx_dft::bytes_of_vec_znx_dft; use crate::znx_base::ZnxInfos; -use crate::{Backend, Scratch, VecZnxBigToMut, VecZnxDftOwned, VecZnxDftToMut, VecZnxDftToRef, VecZnxToRef, ZnxSliceSize}; +use crate::{ + Backend, Scratch, VecZnxBig, VecZnxBigToMut, VecZnxDft, VecZnxDftOwned, VecZnxDftToMut, VecZnxDftToRef, VecZnxToRef, + ZnxSliceSize, +}; use crate::{FFT64, Module, ZnxView, ZnxViewMut, ZnxZero}; use std::cmp::min; @@ -44,6 +47,9 @@ pub trait VecZnxDftOps { where R: VecZnxBigToMut, A: VecZnxDftToMut; + fn vec_znx_idft_consume(&self, a: VecZnxDft, a_cols: usize) -> VecZnxBig + where + VecZnxDft: VecZnxDftToMut; fn vec_znx_idft(&self, res: &mut R, res_col: usize, a: &A, a_col: usize, scratch: &mut Scratch) where @@ -97,6 +103,27 @@ impl VecZnxDftOps for Module { } } + fn vec_znx_idft_consume(&self, mut a: VecZnxDft, a_col: usize) -> VecZnxBig + where + VecZnxDft: VecZnxDftToMut, + { + let mut a_mut: VecZnxDft<&mut [u8], FFT64> = a.to_mut(); + + unsafe { + (0..a_mut.size()).for_each(|j| { + vec_znx_dft::vec_znx_idft_tmp_a( + self.ptr, + a_mut.at_mut_ptr(a_col, j) as *mut vec_znx_big::vec_znx_big_t, + 1 as u64, + a_mut.at_mut_ptr(a_col, j) as *mut vec_znx_dft::vec_znx_dft_t, + 1 as u64, + ) + }); + + a.into_big() + } + } + fn vec_znx_idft_tmp_bytes(&self) -> usize { unsafe { vec_znx_dft::vec_znx_idft_tmp_bytes(self.ptr) as usize } }