diff --git a/base2k/spqlios-arithmetic b/base2k/spqlios-arithmetic index 515d616..07f3c8d 160000 --- a/base2k/spqlios-arithmetic +++ b/base2k/spqlios-arithmetic @@ -1 +1 @@ -Subproject commit 515d616a8ba858b7e858a63dc0fa768eb70ebb99 +Subproject commit 07f3c8d2b8cf9e3a41e9577d06852d6f52e80e88 diff --git a/base2k/src/ffi/vec_znx_dft.rs b/base2k/src/ffi/vec_znx_dft.rs index 54fb117..bcdf5e2 100644 --- a/base2k/src/ffi/vec_znx_dft.rs +++ b/base2k/src/ffi/vec_znx_dft.rs @@ -75,3 +75,19 @@ unsafe extern "C" { a_size: u64, ); } + +unsafe extern "C" { + pub unsafe fn vec_znx_dft_automorphism( + module: *const MODULE, + d: i64, + res_dft: *mut VEC_ZNX_DFT, + res_size: u64, + a_dft: *const VEC_ZNX_DFT, + a_size: u64, + tmp: *mut u8, + ); +} + +unsafe extern "C" { + pub unsafe fn vec_znx_dft_automorphism_tmp_bytes(module: *const MODULE) -> u64; +} diff --git a/base2k/src/vec_znx_dft.rs b/base2k/src/vec_znx_dft.rs index 4ad8525..367270d 100644 --- a/base2k/src/vec_znx_dft.rs +++ b/base2k/src/vec_znx_dft.rs @@ -1,7 +1,7 @@ use crate::ffi::vec_znx_big::vec_znx_big_t; use crate::ffi::vec_znx_dft; use crate::ffi::vec_znx_dft::{bytes_of_vec_znx_dft, vec_znx_dft_t}; -use crate::{alloc_aligned, VecZnx}; +use crate::{alloc_aligned, VecZnx, DEFAULTALIGN}; use crate::{assert_alignement, Infos, Module, VecZnxBig, BACKEND}; pub struct VecZnxDft { @@ -135,15 +135,28 @@ pub trait VecZnxDftOps { /// b <- IDFT(a), uses a as scratch space. fn vec_znx_idft_tmp_a(&self, b: &mut VecZnxBig, a: &mut VecZnxDft, a_cols: usize); - fn vec_znx_idft( + fn vec_znx_idft(&self, b: &mut VecZnxBig, a: &VecZnxDft, a_cols: usize, tmp_bytes: &mut [u8]); + + fn vec_znx_dft(&self, b: &mut VecZnxDft, a: &VecZnx, a_cols: usize); + + fn vec_znx_dft_automorphism( &self, - b: &mut VecZnxBig, + k: i64, + b: &mut VecZnxDft, + b_cols: usize, a: &VecZnxDft, a_cols: usize, + ); + + fn vec_znx_dft_automorphism_inplace( + &self, + k: i64, + a: &mut VecZnxDft, + a_cols: usize, tmp_bytes: &mut [u8], ); - fn vec_znx_dft(&self, b: &mut VecZnxDft, a: &VecZnx, a_cols: usize); + fn vec_znx_dft_automorphism_tmp_bytes(&self) -> usize; } impl VecZnxDftOps for Module { @@ -161,10 +174,10 @@ impl VecZnxDftOps for Module { fn new_vec_znx_dft_from_bytes(&self, cols: usize, tmp_bytes: &mut [u8]) -> VecZnxDft { debug_assert!( - tmp_bytes.len() >= ::bytes_of_vec_znx_dft(self, cols), + tmp_bytes.len() >= Self::bytes_of_vec_znx_dft(self, cols), "invalid bytes: bytes.len()={} < bytes_of_vec_znx_dft={}", tmp_bytes.len(), - ::bytes_of_vec_znx_dft(self, cols) + Self::bytes_of_vec_znx_dft(self, cols) ); #[cfg(debug_assertions)] { @@ -223,33 +236,27 @@ impl VecZnxDftOps for Module { } // b <- IDFT(a), scratch space size obtained with [vec_znx_idft_tmp_bytes]. - fn vec_znx_idft( - &self, - b: &mut VecZnxBig, - a: &VecZnxDft, - a_cols: usize, - tmp_bytes: &mut [u8], - ) { - debug_assert!( - b.cols() >= a_cols, - "invalid c_vector: b.cols()={} < a_cols={}", - b.cols(), - a_cols - ); - debug_assert!( - a.cols() >= a_cols, - "invalid c_vector: a.cols()={} < a_cols={}", - a.cols(), - a_cols - ); - debug_assert!( - tmp_bytes.len() >= ::vec_znx_idft_tmp_bytes(self), - "invalid tmp_bytes: tmp_bytes.len()={} < self.vec_znx_idft_tmp_bytes()={}", - tmp_bytes.len(), - ::vec_znx_idft_tmp_bytes(self) - ); + fn vec_znx_idft(&self, b: &mut VecZnxBig, a: &VecZnxDft, a_cols: usize, tmp_bytes: &mut [u8]) { #[cfg(debug_assertions)] { + assert!( + b.cols() >= a_cols, + "invalid c_vector: b.cols()={} < a_cols={}", + b.cols(), + a_cols + ); + assert!( + a.cols() >= a_cols, + "invalid c_vector: a.cols()={} < a_cols={}", + a.cols(), + a_cols + ); + assert!( + tmp_bytes.len() >= Self::vec_znx_idft_tmp_bytes(self), + "invalid tmp_bytes: tmp_bytes.len()={} < self.vec_znx_idft_tmp_bytes()={}", + tmp_bytes.len(), + Self::vec_znx_idft_tmp_bytes(self) + ); assert_alignement(tmp_bytes.as_ptr()) } unsafe { @@ -263,4 +270,132 @@ impl VecZnxDftOps for Module { ) } } + + fn vec_znx_dft_automorphism( + &self, + k: i64, + b: &mut VecZnxDft, + b_cols: usize, + a: &VecZnxDft, + a_cols: usize, + ) { + #[cfg(debug_assertions)] + { + assert!( + b.cols() >= a_cols, + "invalid c_vector: b.cols()={} < a_cols={}", + b.cols(), + a_cols + ); + assert!( + a.cols() >= a_cols, + "invalid c_vector: a.cols()={} < a_cols={}", + a.cols(), + a_cols + ); + } + unsafe { + vec_znx_dft::vec_znx_dft_automorphism( + self.ptr, + k, + b.ptr as *mut vec_znx_dft_t, + b_cols as u64, + a.ptr as *const vec_znx_dft_t, + a_cols as u64, + [0u8; 0].as_mut_ptr(), + ); + } + } + + fn vec_znx_dft_automorphism_inplace( + &self, + k: i64, + a: &mut VecZnxDft, + a_cols: usize, + tmp_bytes: &mut [u8], + ) { + #[cfg(debug_assertions)] + { + assert!( + a.cols() >= a_cols, + "invalid c_vector: a.cols()={} < a_cols={}", + a.cols(), + a_cols + ); + assert!( + tmp_bytes.len() >= Self::vec_znx_dft_automorphism_tmp_bytes(self), + "invalid tmp_bytes: tmp_bytes.len()={} < self.vec_znx_dft_automorphism_tmp_bytes()={}", + tmp_bytes.len(), + Self::vec_znx_dft_automorphism_tmp_bytes(self) + ); + assert_alignement(tmp_bytes.as_ptr()) + } + unsafe { + vec_znx_dft::vec_znx_dft_automorphism( + self.ptr, + k, + a.ptr as *mut vec_znx_dft_t, + a_cols as u64, + a.ptr as *const vec_znx_dft_t, + a_cols as u64, + tmp_bytes.as_mut_ptr(), + ); + } + } + + fn vec_znx_dft_automorphism_tmp_bytes(&self) -> usize { + unsafe { + std::cmp::max( + vec_znx_dft::vec_znx_dft_automorphism_tmp_bytes(self.ptr) as usize, + DEFAULTALIGN, + ) + } + } +} + +#[cfg(test)] +mod tests { + use crate::{ + alloc_aligned, Module, Sampling, VecZnx, VecZnxDft, VecZnxDftOps, VecZnxOps, BACKEND, + }; + use itertools::izip; + use sampling::source::{new_seed, Source}; + + #[test] + fn test_automorphism_dft() { + let module: Module = Module::new(128, BACKEND::FFT64); + + let cols: usize = 2; + let log_base2k: usize = 17; + let mut a: VecZnx = module.new_vec_znx(cols); + let mut a_dft: VecZnxDft = module.new_vec_znx_dft(cols); + let mut b_dft: VecZnxDft = module.new_vec_znx_dft(cols); + + let mut source: Source = Source::new(new_seed()); + module.fill_uniform(log_base2k, &mut a, cols, &mut source); + + let mut tmp_bytes: Vec = alloc_aligned(module.vec_znx_dft_automorphism_tmp_bytes()); + + let p: i64 = -5; + + // a_dft <- DFT(a) + module.vec_znx_dft(&mut a_dft, &a, cols); + + // a_dft <- AUTO(a_dft) + module.vec_znx_dft_automorphism_inplace(p, &mut a_dft, cols, &mut tmp_bytes); + + // a <- AUTO(a) + module.vec_znx_automorphism_inplace(p, &mut a, cols); + + // b_dft <- DFT(AUTO(a)) + module.vec_znx_dft(&mut b_dft, &a, cols); + + let a_f64: &[f64] = a_dft.raw(&module); + let b_f64: &[f64] = b_dft.raw(&module); + izip!(a_f64.iter(), b_f64.iter()).for_each(|(ai, bi)| { + assert!((ai - bi).abs() <= 1e-9, "{:+e} > 1e-9", (ai - bi).abs()); + }); + + module.free() + } }