Updated svp

This commit is contained in:
Jean-Philippe Bossuat
2025-05-06 16:02:32 +02:00
parent 669450c4f1
commit f9b194cca1
4 changed files with 44 additions and 12 deletions

View File

@@ -40,14 +40,14 @@ fn main() {
let mut buf_dft: VecZnxDft<Vec<u8>, FFT64> = module.new_vec_znx_dft(1, ct_size);
module.vec_znx_dft(&mut buf_dft, 0, &ct, 1);
// Applies DFT(ct[1]) * DFT(s)
module.svp_apply(
module.svp_apply_dft_inplace(
&mut buf_dft, // DFT(ct[1] * s)
0, // Selects the first column of res
&s_dft, // DFT(s)
0, // Selects the first column of s_dft
&ct,
1, // Selects the second column of ct
);
// Alias scratch space (VecZnxDft<B> is always at least as big as VecZnxBig<B>)

View File

@@ -33,3 +33,14 @@ unsafe extern "C" {
a_sl: u64,
);
}
unsafe extern "C" {
pub unsafe fn svp_apply_dft_to_dft(
module: *const MODULE,
res: *const VEC_ZNX_DFT,
res_size: u64,
ppol: *const SVP_PPOL,
a: *const VEC_ZNX_DFT,
a_size: u64,
);
}

View File

@@ -3,14 +3,13 @@ use crate::ffi::vec_znx_dft::vec_znx_dft_t;
use crate::znx_base::{ZnxInfos, ZnxView, ZnxViewMut};
use crate::{
Backend, FFT64, Module, ScalarToRef, ScalarZnxDft, ScalarZnxDftOwned, ScalarZnxDftToMut, ScalarZnxDftToRef, VecZnx,
VecZnxDft, VecZnxDftToMut, VecZnxToRef, ZnxSliceSize,
VecZnxDft, VecZnxDftToMut, VecZnxDftToRef, VecZnxToRef, ZnxSliceSize,
};
pub trait ScalarZnxDftAlloc<B: Backend> {
fn new_scalar_znx_dft(&self, cols: usize) -> ScalarZnxDftOwned<B>;
fn bytes_of_scalar_znx_dft(&self, cols: usize) -> usize;
fn new_scalar_znx_dft_from_bytes(&self, cols: usize, bytes: Vec<u8>) -> ScalarZnxDftOwned<B>;
// fn new_scalar_znx_dft_from_bytes_borrow(&self, cols: usize, bytes: &mut [u8]) -> ScalarZnxDft<B>;
}
pub trait ScalarZnxDftOps<BACKEND: Backend> {
@@ -22,7 +21,11 @@ pub trait ScalarZnxDftOps<BACKEND: Backend> {
where
R: VecZnxDftToMut<BACKEND>,
A: ScalarZnxDftToRef<BACKEND>,
B: VecZnxToRef;
B: VecZnxDftToRef<FFT64>;
fn svp_apply_inplace<R, A>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize)
where
R: VecZnxDftToMut<BACKEND>,
A: ScalarZnxDftToRef<BACKEND>;
}
impl<B: Backend> ScalarZnxDftAlloc<B> for Module<B> {
@@ -58,20 +61,38 @@ impl ScalarZnxDftOps<FFT64> for Module<FFT64> {
where
R: VecZnxDftToMut<FFT64>,
A: ScalarZnxDftToRef<FFT64>,
B: VecZnxToRef,
B: VecZnxDftToRef<FFT64>,
{
let mut res: VecZnxDft<&mut [u8], FFT64> = res.to_mut();
let a: ScalarZnxDft<&[u8], FFT64> = a.to_ref();
let b: VecZnx<&[u8]> = b.to_ref();
let b: VecZnxDft<&[u8], FFT64> = b.to_ref();
unsafe {
svp::svp_apply_dft(
svp::svp_apply_dft_to_dft(
self.ptr,
res.at_mut_ptr(res_col, 0) as *mut vec_znx_dft_t,
res.size() as u64,
a.at_ptr(a_col, 0) as *const svp::svp_ppol_t,
b.at_ptr(b_col, 0),
b.at_ptr(b_col, 0) as *const vec_znx_dft_t,
b.size() as u64,
b.sl() as u64,
)
}
}
fn svp_apply_inplace<R, A>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize)
where
R: VecZnxDftToMut<FFT64>,
A: ScalarZnxDftToRef<FFT64>,
{
let mut res: VecZnxDft<&mut [u8], FFT64> = res.to_mut();
let a: ScalarZnxDft<&[u8], FFT64> = a.to_ref();
unsafe {
svp::svp_apply_dft_to_dft(
self.ptr,
res.at_mut_ptr(res_col, 0) as *mut vec_znx_dft_t,
res.size() as u64,
a.at_ptr(a_col, 0) as *const svp::svp_ppol_t,
res.at_ptr(res_col, 0) as *const vec_znx_dft_t,
res.size() as u64,
)
}
}