This commit is contained in:
Jean-Philippe Bossuat
2025-04-22 18:50:51 +02:00
parent d3e3594ae8
commit fbdb4436b2
18 changed files with 908 additions and 403 deletions

View File

@@ -61,14 +61,6 @@ impl VecZnxDft {
}
}
pub fn n(&self) -> usize {
self.n
}
pub fn cols(&self) -> usize {
self.cols
}
pub fn backend(&self) -> BACKEND {
self.backend
}
@@ -102,12 +94,36 @@ impl VecZnxDft {
}
}
impl Infos for VecZnxDft {
/// Returns the base 2 logarithm of the [VecZnx] degree.
fn log_n(&self) -> usize {
(usize::BITS - (self.n - 1).leading_zeros()) as _
}
/// Returns the [VecZnx] degree.
fn n(&self) -> usize {
self.n
}
/// Returns the number of cols of the [VecZnx].
fn cols(&self) -> usize {
self.cols
}
/// Returns the number of rows of the [VecZnx].
fn rows(&self) -> usize {
1
}
}
pub trait VecZnxDftOps {
/// Allocates a vector Z[X]/(X^N+1) that stores normalized in the DFT space.
fn new_vec_znx_dft(&self, cols: usize) -> VecZnxDft;
/// Returns a new [VecZnxDft] with the provided bytes array as backing array.
///
/// Behavior: takes ownership of the backing array.
///
/// # Arguments
///
/// * `cols`: the number of cols of the [VecZnxDft].
@@ -117,6 +133,19 @@ pub trait VecZnxDftOps {
/// If `bytes.len()` < [Module::bytes_of_vec_znx_dft].
fn new_vec_znx_dft_from_bytes(&self, cols: usize, bytes: &mut [u8]) -> VecZnxDft;
/// Returns a new [VecZnxDft] with the provided bytes array as backing array.
///
/// Behavior: the backing array is only borrowed.
///
/// # Arguments
///
/// * `cols`: the number of cols of the [VecZnxDft].
/// * `bytes`: a byte array of size at least [Module::bytes_of_vec_znx_dft].
///
/// # Panics
/// If `bytes.len()` < [Module::bytes_of_vec_znx_dft].
fn new_vec_znx_dft_from_bytes_borrow(&self, cols: usize, bytes: &mut [u8]) -> VecZnxDft;
/// Returns a new [VecZnxDft] with the provided bytes array as backing array.
///
/// # Arguments
@@ -133,28 +162,15 @@ pub trait VecZnxDftOps {
fn vec_znx_idft_tmp_bytes(&self) -> usize;
/// b <- IDFT(a), uses a as scratch space.
fn vec_znx_idft_tmp_a(&self, b: &mut VecZnxBig, a: &mut VecZnxDft, a_cols: usize);
fn vec_znx_idft_tmp_a(&self, b: &mut VecZnxBig, a: &mut VecZnxDft);
fn vec_znx_idft(&self, b: &mut VecZnxBig, a: &VecZnxDft, a_cols: usize, tmp_bytes: &mut [u8]);
fn vec_znx_idft(&self, b: &mut VecZnxBig, a: &VecZnxDft, tmp_bytes: &mut [u8]);
fn vec_znx_dft(&self, b: &mut VecZnxDft, a: &VecZnx, a_cols: usize);
fn vec_znx_dft(&self, b: &mut VecZnxDft, a: &VecZnx);
fn vec_znx_dft_automorphism(
&self,
k: i64,
b: &mut VecZnxDft,
b_cols: usize,
a: &VecZnxDft,
a_cols: usize,
);
fn vec_znx_dft_automorphism(&self, k: i64, b: &mut VecZnxDft, a: &VecZnxDft);
fn vec_znx_dft_automorphism_inplace(
&self,
k: i64,
a: &mut VecZnxDft,
a_cols: usize,
tmp_bytes: &mut [u8],
);
fn vec_znx_dft_automorphism_inplace(&self, k: i64, a: &mut VecZnxDft, tmp_bytes: &mut [u8]);
fn vec_znx_dft_automorphism_tmp_bytes(&self) -> usize;
}
@@ -173,37 +189,25 @@ impl VecZnxDftOps for Module {
}
fn new_vec_znx_dft_from_bytes(&self, cols: usize, tmp_bytes: &mut [u8]) -> VecZnxDft {
debug_assert!(
tmp_bytes.len() >= Self::bytes_of_vec_znx_dft(self, cols),
"invalid bytes: bytes.len()={} < bytes_of_vec_znx_dft={}",
tmp_bytes.len(),
Self::bytes_of_vec_znx_dft(self, cols)
);
#[cfg(debug_assertions)]
{
assert_alignement(tmp_bytes.as_ptr())
}
VecZnxDft::from_bytes(self, cols, tmp_bytes)
}
fn new_vec_znx_dft_from_bytes_borrow(&self, cols: usize, tmp_bytes: &mut [u8]) -> VecZnxDft {
VecZnxDft::from_bytes_borrow(self, cols, tmp_bytes)
}
fn bytes_of_vec_znx_dft(&self, cols: usize) -> usize {
unsafe { bytes_of_vec_znx_dft(self.ptr, cols as u64) as usize }
}
fn vec_znx_idft_tmp_a(&self, b: &mut VecZnxBig, a: &mut VecZnxDft, a_cols: usize) {
debug_assert!(
b.cols() >= a_cols,
"invalid c_vector: b_vector.cols()={} < a_cols={}",
b.cols(),
a_cols
);
fn vec_znx_idft_tmp_a(&self, b: &mut VecZnxBig, a: &mut VecZnxDft) {
unsafe {
vec_znx_dft::vec_znx_idft_tmp_a(
self.ptr,
b.ptr as *mut vec_znx_big_t,
b.cols() as u64,
a.ptr as *mut vec_znx_dft_t,
a_cols as u64,
a.cols() as u64,
)
}
}
@@ -216,41 +220,23 @@ impl VecZnxDftOps for Module {
///
/// # Panics
/// If b.cols < a_cols
fn vec_znx_dft(&self, b: &mut VecZnxDft, a: &VecZnx, a_cols: usize) {
debug_assert!(
b.cols() >= a_cols,
"invalid a_cols: b.cols()={} < a_cols={}",
b.cols(),
a_cols
);
fn vec_znx_dft(&self, b: &mut VecZnxDft, a: &VecZnx) {
unsafe {
vec_znx_dft::vec_znx_dft(
self.ptr,
b.ptr as *mut vec_znx_dft_t,
b.cols() as u64,
a.as_ptr(),
a_cols as u64,
a.cols() as u64,
a.n() as u64,
)
}
}
// b <- IDFT(a), scratch space size obtained with [vec_znx_idft_tmp_bytes].
fn vec_znx_idft(&self, b: &mut VecZnxBig, a: &VecZnxDft, a_cols: usize, tmp_bytes: &mut [u8]) {
fn vec_znx_idft(&self, b: &mut VecZnxBig, a: &VecZnxDft, tmp_bytes: &mut [u8]) {
#[cfg(debug_assertions)]
{
assert!(
b.cols() >= a_cols,
"invalid c_vector: b.cols()={} < a_cols={}",
b.cols(),
a_cols
);
assert!(
a.cols() >= a_cols,
"invalid c_vector: a.cols()={} < a_cols={}",
a.cols(),
a_cols
);
assert!(
tmp_bytes.len() >= Self::vec_znx_idft_tmp_bytes(self),
"invalid tmp_bytes: tmp_bytes.len()={} < self.vec_znx_idft_tmp_bytes()={}",
@@ -263,65 +249,31 @@ impl VecZnxDftOps for Module {
vec_znx_dft::vec_znx_idft(
self.ptr,
b.ptr as *mut vec_znx_big_t,
a.cols() as u64,
b.cols() as u64,
a.ptr as *const vec_znx_dft_t,
a_cols as u64,
a.cols() as u64,
tmp_bytes.as_mut_ptr(),
)
}
}
fn vec_znx_dft_automorphism(
&self,
k: i64,
b: &mut VecZnxDft,
b_cols: usize,
a: &VecZnxDft,
a_cols: usize,
) {
#[cfg(debug_assertions)]
{
assert!(
b.cols() >= a_cols,
"invalid c_vector: b.cols()={} < a_cols={}",
b.cols(),
a_cols
);
assert!(
a.cols() >= a_cols,
"invalid c_vector: a.cols()={} < a_cols={}",
a.cols(),
a_cols
);
}
fn vec_znx_dft_automorphism(&self, k: i64, b: &mut VecZnxDft, a: &VecZnxDft) {
unsafe {
vec_znx_dft::vec_znx_dft_automorphism(
self.ptr,
k,
b.ptr as *mut vec_znx_dft_t,
b_cols as u64,
b.cols() as u64,
a.ptr as *const vec_znx_dft_t,
a_cols as u64,
a.cols() as u64,
[0u8; 0].as_mut_ptr(),
);
}
}
fn vec_znx_dft_automorphism_inplace(
&self,
k: i64,
a: &mut VecZnxDft,
a_cols: usize,
tmp_bytes: &mut [u8],
) {
fn vec_znx_dft_automorphism_inplace(&self, k: i64, a: &mut VecZnxDft, tmp_bytes: &mut [u8]) {
#[cfg(debug_assertions)]
{
assert!(
a.cols() >= a_cols,
"invalid c_vector: a.cols()={} < a_cols={}",
a.cols(),
a_cols
);
assert!(
tmp_bytes.len() >= Self::vec_znx_dft_automorphism_tmp_bytes(self),
"invalid tmp_bytes: tmp_bytes.len()={} < self.vec_znx_dft_automorphism_tmp_bytes()={}",
@@ -335,9 +287,9 @@ impl VecZnxDftOps for Module {
self.ptr,
k,
a.ptr as *mut vec_znx_dft_t,
a_cols as u64,
a.cols() as u64,
a.ptr as *const vec_znx_dft_t,
a_cols as u64,
a.cols() as u64,
tmp_bytes.as_mut_ptr(),
);
}
@@ -379,16 +331,16 @@ mod tests {
let p: i64 = -5;
// a_dft <- DFT(a)
module.vec_znx_dft(&mut a_dft, &a, cols);
module.vec_znx_dft(&mut a_dft, &a);
// a_dft <- AUTO(a_dft)
module.vec_znx_dft_automorphism_inplace(p, &mut a_dft, cols, &mut tmp_bytes);
module.vec_znx_dft_automorphism_inplace(p, &mut a_dft, &mut tmp_bytes);
// a <- AUTO(a)
module.vec_znx_automorphism_inplace(p, &mut a, cols);
module.vec_znx_automorphism_inplace(p, &mut a);
// b_dft <- DFT(AUTO(a))
module.vec_znx_dft(&mut b_dft, &a, cols);
module.vec_znx_dft(&mut b_dft, &a);
let a_f64: &[f64] = a_dft.raw(&module);
let b_f64: &[f64] = b_dft.raw(&module);