Added support for automorphism in DFT

This commit is contained in:
Jean-Philippe Bossuat
2025-04-16 18:12:27 +02:00
parent db01092c5e
commit 52c78c9085
3 changed files with 183 additions and 32 deletions

View File

@@ -75,3 +75,19 @@ unsafe extern "C" {
a_size: u64, a_size: u64,
); );
} }
unsafe extern "C" {
pub unsafe fn vec_znx_dft_automorphism(
module: *const MODULE,
d: i64,
res_dft: *mut VEC_ZNX_DFT,
res_size: u64,
a_dft: *const VEC_ZNX_DFT,
a_size: u64,
tmp: *mut u8,
);
}
unsafe extern "C" {
pub unsafe fn vec_znx_dft_automorphism_tmp_bytes(module: *const MODULE) -> u64;
}

View File

@@ -1,7 +1,7 @@
use crate::ffi::vec_znx_big::vec_znx_big_t; use crate::ffi::vec_znx_big::vec_znx_big_t;
use crate::ffi::vec_znx_dft; use crate::ffi::vec_znx_dft;
use crate::ffi::vec_znx_dft::{bytes_of_vec_znx_dft, vec_znx_dft_t}; use crate::ffi::vec_znx_dft::{bytes_of_vec_znx_dft, vec_znx_dft_t};
use crate::{alloc_aligned, VecZnx}; use crate::{alloc_aligned, VecZnx, DEFAULTALIGN};
use crate::{assert_alignement, Infos, Module, VecZnxBig, BACKEND}; use crate::{assert_alignement, Infos, Module, VecZnxBig, BACKEND};
pub struct VecZnxDft { pub struct VecZnxDft {
@@ -135,15 +135,28 @@ pub trait VecZnxDftOps {
/// b <- IDFT(a), uses a as scratch space. /// b <- IDFT(a), uses a as scratch space.
fn vec_znx_idft_tmp_a(&self, b: &mut VecZnxBig, a: &mut VecZnxDft, a_cols: usize); fn vec_znx_idft_tmp_a(&self, b: &mut VecZnxBig, a: &mut VecZnxDft, a_cols: usize);
fn vec_znx_idft( fn vec_znx_idft(&self, b: &mut VecZnxBig, a: &VecZnxDft, a_cols: usize, tmp_bytes: &mut [u8]);
fn vec_znx_dft(&self, b: &mut VecZnxDft, a: &VecZnx, a_cols: usize);
fn vec_znx_dft_automorphism(
&self, &self,
b: &mut VecZnxBig, k: i64,
b: &mut VecZnxDft,
b_cols: usize,
a: &VecZnxDft, a: &VecZnxDft,
a_cols: usize, a_cols: usize,
);
fn vec_znx_dft_automorphism_inplace(
&self,
k: i64,
a: &mut VecZnxDft,
a_cols: usize,
tmp_bytes: &mut [u8], tmp_bytes: &mut [u8],
); );
fn vec_znx_dft(&self, b: &mut VecZnxDft, a: &VecZnx, a_cols: usize); fn vec_znx_dft_automorphism_tmp_bytes(&self) -> usize;
} }
impl VecZnxDftOps for Module { impl VecZnxDftOps for Module {
@@ -161,10 +174,10 @@ impl VecZnxDftOps for Module {
fn new_vec_znx_dft_from_bytes(&self, cols: usize, tmp_bytes: &mut [u8]) -> VecZnxDft { fn new_vec_znx_dft_from_bytes(&self, cols: usize, tmp_bytes: &mut [u8]) -> VecZnxDft {
debug_assert!( debug_assert!(
tmp_bytes.len() >= <Module as VecZnxDftOps>::bytes_of_vec_znx_dft(self, cols), tmp_bytes.len() >= Self::bytes_of_vec_znx_dft(self, cols),
"invalid bytes: bytes.len()={} < bytes_of_vec_znx_dft={}", "invalid bytes: bytes.len()={} < bytes_of_vec_znx_dft={}",
tmp_bytes.len(), tmp_bytes.len(),
<Module as VecZnxDftOps>::bytes_of_vec_znx_dft(self, cols) Self::bytes_of_vec_znx_dft(self, cols)
); );
#[cfg(debug_assertions)] #[cfg(debug_assertions)]
{ {
@@ -223,33 +236,27 @@ impl VecZnxDftOps for Module {
} }
// b <- IDFT(a), scratch space size obtained with [vec_znx_idft_tmp_bytes]. // b <- IDFT(a), scratch space size obtained with [vec_znx_idft_tmp_bytes].
fn vec_znx_idft( fn vec_znx_idft(&self, b: &mut VecZnxBig, a: &VecZnxDft, a_cols: usize, tmp_bytes: &mut [u8]) {
&self,
b: &mut VecZnxBig,
a: &VecZnxDft,
a_cols: usize,
tmp_bytes: &mut [u8],
) {
debug_assert!(
b.cols() >= a_cols,
"invalid c_vector: b.cols()={} < a_cols={}",
b.cols(),
a_cols
);
debug_assert!(
a.cols() >= a_cols,
"invalid c_vector: a.cols()={} < a_cols={}",
a.cols(),
a_cols
);
debug_assert!(
tmp_bytes.len() >= <Module as VecZnxDftOps>::vec_znx_idft_tmp_bytes(self),
"invalid tmp_bytes: tmp_bytes.len()={} < self.vec_znx_idft_tmp_bytes()={}",
tmp_bytes.len(),
<Module as VecZnxDftOps>::vec_znx_idft_tmp_bytes(self)
);
#[cfg(debug_assertions)] #[cfg(debug_assertions)]
{ {
assert!(
b.cols() >= a_cols,
"invalid c_vector: b.cols()={} < a_cols={}",
b.cols(),
a_cols
);
assert!(
a.cols() >= a_cols,
"invalid c_vector: a.cols()={} < a_cols={}",
a.cols(),
a_cols
);
assert!(
tmp_bytes.len() >= Self::vec_znx_idft_tmp_bytes(self),
"invalid tmp_bytes: tmp_bytes.len()={} < self.vec_znx_idft_tmp_bytes()={}",
tmp_bytes.len(),
Self::vec_znx_idft_tmp_bytes(self)
);
assert_alignement(tmp_bytes.as_ptr()) assert_alignement(tmp_bytes.as_ptr())
} }
unsafe { unsafe {
@@ -263,4 +270,132 @@ impl VecZnxDftOps for Module {
) )
} }
} }
fn vec_znx_dft_automorphism(
&self,
k: i64,
b: &mut VecZnxDft,
b_cols: usize,
a: &VecZnxDft,
a_cols: usize,
) {
#[cfg(debug_assertions)]
{
assert!(
b.cols() >= a_cols,
"invalid c_vector: b.cols()={} < a_cols={}",
b.cols(),
a_cols
);
assert!(
a.cols() >= a_cols,
"invalid c_vector: a.cols()={} < a_cols={}",
a.cols(),
a_cols
);
}
unsafe {
vec_znx_dft::vec_znx_dft_automorphism(
self.ptr,
k,
b.ptr as *mut vec_znx_dft_t,
b_cols as u64,
a.ptr as *const vec_znx_dft_t,
a_cols as u64,
[0u8; 0].as_mut_ptr(),
);
}
}
fn vec_znx_dft_automorphism_inplace(
&self,
k: i64,
a: &mut VecZnxDft,
a_cols: usize,
tmp_bytes: &mut [u8],
) {
#[cfg(debug_assertions)]
{
assert!(
a.cols() >= a_cols,
"invalid c_vector: a.cols()={} < a_cols={}",
a.cols(),
a_cols
);
assert!(
tmp_bytes.len() >= Self::vec_znx_dft_automorphism_tmp_bytes(self),
"invalid tmp_bytes: tmp_bytes.len()={} < self.vec_znx_dft_automorphism_tmp_bytes()={}",
tmp_bytes.len(),
Self::vec_znx_dft_automorphism_tmp_bytes(self)
);
assert_alignement(tmp_bytes.as_ptr())
}
unsafe {
vec_znx_dft::vec_znx_dft_automorphism(
self.ptr,
k,
a.ptr as *mut vec_znx_dft_t,
a_cols as u64,
a.ptr as *const vec_znx_dft_t,
a_cols as u64,
tmp_bytes.as_mut_ptr(),
);
}
}
fn vec_znx_dft_automorphism_tmp_bytes(&self) -> usize {
unsafe {
std::cmp::max(
vec_znx_dft::vec_znx_dft_automorphism_tmp_bytes(self.ptr) as usize,
DEFAULTALIGN,
)
}
}
}
#[cfg(test)]
mod tests {
use crate::{
alloc_aligned, Module, Sampling, VecZnx, VecZnxDft, VecZnxDftOps, VecZnxOps, BACKEND,
};
use itertools::izip;
use sampling::source::{new_seed, Source};
#[test]
fn test_automorphism_dft() {
let module: Module = Module::new(128, BACKEND::FFT64);
let cols: usize = 2;
let log_base2k: usize = 17;
let mut a: VecZnx = module.new_vec_znx(cols);
let mut a_dft: VecZnxDft = module.new_vec_znx_dft(cols);
let mut b_dft: VecZnxDft = module.new_vec_znx_dft(cols);
let mut source: Source = Source::new(new_seed());
module.fill_uniform(log_base2k, &mut a, cols, &mut source);
let mut tmp_bytes: Vec<u8> = alloc_aligned(module.vec_znx_dft_automorphism_tmp_bytes());
let p: i64 = -5;
// a_dft <- DFT(a)
module.vec_znx_dft(&mut a_dft, &a, cols);
// a_dft <- AUTO(a_dft)
module.vec_znx_dft_automorphism_inplace(p, &mut a_dft, cols, &mut tmp_bytes);
// a <- AUTO(a)
module.vec_znx_automorphism_inplace(p, &mut a, cols);
// b_dft <- DFT(AUTO(a))
module.vec_znx_dft(&mut b_dft, &a, cols);
let a_f64: &[f64] = a_dft.raw(&module);
let b_f64: &[f64] = b_dft.raw(&module);
izip!(a_f64.iter(), b_f64.iter()).for_each(|(ai, bi)| {
assert!((ai - bi).abs() <= 1e-9, "{:+e} > 1e-9", (ai - bi).abs());
});
module.free()
}
} }