diff --git a/base2k/examples/rlwe_encrypt.rs b/base2k/examples/rlwe_encrypt.rs index 79270ea..1a9f4b3 100644 --- a/base2k/examples/rlwe_encrypt.rs +++ b/base2k/examples/rlwe_encrypt.rs @@ -40,14 +40,14 @@ fn main() { let mut buf_dft: VecZnxDft, FFT64> = module.new_vec_znx_dft(1, ct_size); + module.vec_znx_dft(&mut buf_dft, 0, &ct, 1); + // Applies DFT(ct[1]) * DFT(s) - module.svp_apply( + module.svp_apply_dft_inplace( &mut buf_dft, // DFT(ct[1] * s) 0, // Selects the first column of res &s_dft, // DFT(s) 0, // Selects the first column of s_dft - &ct, - 1, // Selects the second column of ct ); // Alias scratch space (VecZnxDft is always at least as big as VecZnxBig) diff --git a/base2k/spqlios-arithmetic b/base2k/spqlios-arithmetic index 8135d85..b6fa494 160000 --- a/base2k/spqlios-arithmetic +++ b/base2k/spqlios-arithmetic @@ -1 +1 @@ -Subproject commit 8135d85e7ac14601568fdd228e7dedf88994f7cf +Subproject commit b6fa494a14c52842712f8ff032ea80812467dec2 diff --git a/base2k/src/ffi/svp.rs b/base2k/src/ffi/svp.rs index 71c871d..9d4999f 100644 --- a/base2k/src/ffi/svp.rs +++ b/base2k/src/ffi/svp.rs @@ -33,3 +33,14 @@ unsafe extern "C" { a_sl: u64, ); } + +unsafe extern "C" { + pub unsafe fn svp_apply_dft_to_dft( + module: *const MODULE, + res: *const VEC_ZNX_DFT, + res_size: u64, + ppol: *const SVP_PPOL, + a: *const VEC_ZNX_DFT, + a_size: u64, + ); +} diff --git a/base2k/src/scalar_znx_dft_ops.rs b/base2k/src/scalar_znx_dft_ops.rs index ea98a57..a4b3ccc 100644 --- a/base2k/src/scalar_znx_dft_ops.rs +++ b/base2k/src/scalar_znx_dft_ops.rs @@ -3,14 +3,13 @@ use crate::ffi::vec_znx_dft::vec_znx_dft_t; use crate::znx_base::{ZnxInfos, ZnxView, ZnxViewMut}; use crate::{ Backend, FFT64, Module, ScalarToRef, ScalarZnxDft, ScalarZnxDftOwned, ScalarZnxDftToMut, ScalarZnxDftToRef, VecZnx, - VecZnxDft, VecZnxDftToMut, VecZnxToRef, ZnxSliceSize, + VecZnxDft, VecZnxDftToMut, VecZnxDftToRef, VecZnxToRef, ZnxSliceSize, }; pub trait ScalarZnxDftAlloc { fn new_scalar_znx_dft(&self, cols: usize) -> ScalarZnxDftOwned; fn bytes_of_scalar_znx_dft(&self, cols: usize) -> usize; fn new_scalar_znx_dft_from_bytes(&self, cols: usize, bytes: Vec) -> ScalarZnxDftOwned; - // fn new_scalar_znx_dft_from_bytes_borrow(&self, cols: usize, bytes: &mut [u8]) -> ScalarZnxDft; } pub trait ScalarZnxDftOps { @@ -22,7 +21,11 @@ pub trait ScalarZnxDftOps { where R: VecZnxDftToMut, A: ScalarZnxDftToRef, - B: VecZnxToRef; + B: VecZnxDftToRef; + fn svp_apply_inplace(&self, res: &mut R, res_col: usize, a: &A, a_col: usize) + where + R: VecZnxDftToMut, + A: ScalarZnxDftToRef; } impl ScalarZnxDftAlloc for Module { @@ -58,20 +61,38 @@ impl ScalarZnxDftOps for Module { where R: VecZnxDftToMut, A: ScalarZnxDftToRef, - B: VecZnxToRef, + B: VecZnxDftToRef, { let mut res: VecZnxDft<&mut [u8], FFT64> = res.to_mut(); let a: ScalarZnxDft<&[u8], FFT64> = a.to_ref(); - let b: VecZnx<&[u8]> = b.to_ref(); + let b: VecZnxDft<&[u8], FFT64> = b.to_ref(); unsafe { - svp::svp_apply_dft( + svp::svp_apply_dft_to_dft( self.ptr, res.at_mut_ptr(res_col, 0) as *mut vec_znx_dft_t, res.size() as u64, a.at_ptr(a_col, 0) as *const svp::svp_ppol_t, - b.at_ptr(b_col, 0), + b.at_ptr(b_col, 0) as *const vec_znx_dft_t, b.size() as u64, - b.sl() as u64, + ) + } + } + + fn svp_apply_inplace(&self, res: &mut R, res_col: usize, a: &A, a_col: usize) + where + R: VecZnxDftToMut, + A: ScalarZnxDftToRef, + { + let mut res: VecZnxDft<&mut [u8], FFT64> = res.to_mut(); + let a: ScalarZnxDft<&[u8], FFT64> = a.to_ref(); + unsafe { + svp::svp_apply_dft_to_dft( + self.ptr, + res.at_mut_ptr(res_col, 0) as *mut vec_znx_dft_t, + res.size() as u64, + a.at_ptr(a_col, 0) as *const svp::svp_ppol_t, + res.at_ptr(res_col, 0) as *const vec_znx_dft_t, + res.size() as u64, ) } }