fixed all tests

Jean-Philippe Bossuat
2025-06-06 14:06:36 +02:00
parent 33795df6c2
commit 113231da55
28 changed files with 1817 additions and 959 deletions

View File

@@ -40,7 +40,7 @@ fn main() {
let mut buf_dft: VecZnxDft<Vec<u8>, FFT64> = module.new_vec_znx_dft(1, ct_size);
module.vec_znx_dft(&mut buf_dft, 0, &ct, 1);
module.vec_znx_dft(1, 0, &mut buf_dft, 0, &ct, 1);
// Applies DFT(ct[1]) * DFT(s)
module.svp_apply_inplace(
@@ -102,7 +102,7 @@ fn main() {
// Decryption
// DFT(ct[1] * s)
module.vec_znx_dft(&mut buf_dft, 0, &ct, 1);
module.vec_znx_dft(1, 0, &mut buf_dft, 0, &ct, 1);
module.svp_apply_inplace(
&mut buf_dft,
0, // Selects the first column of res.

View File

@@ -393,7 +393,7 @@ mod tests {
let mut source: Source = Source::new([0u8; 32]);
(0..mat_cols_out).for_each(|col_out| {
a.fill_uniform(basek, col_out, mat_size, &mut source);
module.vec_znx_dft(&mut a_dft, col_out, &a, col_out);
module.vec_znx_dft(1, 0, &mut a_dft, col_out, &a, col_out);
});
module.vmp_prepare_row(&mut mat, row_i, col_in, &a_dft);
module.vmp_extract_row(&mut b_dft, &mat, row_i, col_in);
@@ -453,7 +453,7 @@ mod tests {
(0..mat_cols_out).for_each(|col_out_i| {
let idx = 1 + col_in_i * mat_cols_out + col_out_i;
tmp.at_mut(col_out_i, row_i)[idx] = 1 as i64; // X^{idx}
module.vec_znx_dft(&mut c_dft, col_out_i, &tmp, col_out_i);
module.vec_znx_dft(1, 0, &mut c_dft, col_out_i, &tmp, col_out_i);
tmp.at_mut(col_out_i, row_i)[idx] = 0 as i64;
});
module.vmp_prepare_row(&mut mat_znx_dft, row_i, col_in_i, &c_dft);
@@ -462,7 +462,7 @@ mod tests {
let mut a_dft: VecZnxDft<Vec<u8>, FFT64> = module.new_vec_znx_dft(a_cols, a_size);
(0..a_cols).for_each(|i| {
module.vec_znx_dft(&mut a_dft, i, &a, i);
module.vec_znx_dft(1, 0, &mut a_dft, i, &a, i);
});
module.vmp_apply(&mut c_dft, &a_dft, &mat_znx_dft, scratch.borrow());
@@ -489,14 +489,15 @@ mod tests {
#[test]
fn vmp_apply_add() {
let log_n: i32 = 5;
let log_n: i32 = 4;
let n: usize = 1 << log_n;
let module: Module<FFT64> = Module::<FFT64>::new(n);
let basek: usize = 8;
let a_size: usize = 6;
let mat_size: usize = 6;
let a_size: usize = 5;
let mat_size: usize = 5;
let res_size: usize = a_size;
let mut source: Source = Source::new([0u8; 32]);
[1, 2].iter().for_each(|in_cols| {
[1, 2].iter().for_each(|out_cols| {
@@ -521,10 +522,8 @@ mod tests {
let mut a: VecZnx<Vec<u8>> = module.new_vec_znx(a_cols, a_size);
(0..a_cols).for_each(|i| {
(0..a_size).for_each(|j| {
a.at_mut(i, j)[i + 1] = 1 + j as i64;
});
(0..a_cols).for_each(|col_i| {
a.fill_uniform(basek, col_i, a.size(), &mut source);
});
let mut mat_znx_dft: MatZnxDft<Vec<u8>, FFT64> =
@@ -539,9 +538,9 @@ mod tests {
(0..a.size()).for_each(|row_i| {
(0..mat_cols_in).for_each(|col_in_i| {
(0..mat_cols_out).for_each(|col_out_i| {
let idx = 1 + col_in_i * mat_cols_out + col_out_i;
let idx: usize = 1 + col_in_i * mat_cols_out + col_out_i;
tmp.at_mut(col_out_i, row_i)[idx] = 1 as i64; // X^{idx}
module.vec_znx_dft(&mut c_dft, col_out_i, &tmp, col_out_i);
module.vec_znx_dft(1, 0, &mut c_dft, col_out_i, &tmp, col_out_i);
tmp.at_mut(col_out_i, row_i)[idx] = 0 as i64;
});
module.vmp_prepare_row(&mut mat_znx_dft, row_i, col_in_i, &c_dft);
@@ -550,12 +549,12 @@ mod tests {
let mut a_dft: VecZnxDft<Vec<u8>, FFT64> = module.new_vec_znx_dft(a_cols, a_size);
(0..a_cols).for_each(|i| {
module.vec_znx_dft(&mut a_dft, i, &a, i);
module.vec_znx_dft(1, 0, &mut a_dft, i, &a, i);
});
c_dft.zero();
(0..c_dft.cols()).for_each(|i| {
module.vec_znx_dft(&mut c_dft, i, &a, 0);
module.vec_znx_dft(1, 0, &mut c_dft, i, &a, 0);
});
module.vmp_apply_add(&mut c_dft, &a_dft, &mat_znx_dft, shift, scratch.borrow());
@@ -582,16 +581,118 @@ mod tests {
);
(0..res_cols).for_each(|i| {
module.vec_znx_add_inplace(&mut res_want, i, &a, 0);
module.vec_znx_normalize_inplace(basek, &mut res_want, i, scratch.borrow());
});
let mut res_have_vi64: Vec<i64> = vec![i64::default(); n];
let mut res_want_vi64: Vec<i64> = vec![i64::default(); n];
assert_eq!(res_want, res_have);
});
});
});
}
(0..mat_cols_out).for_each(|col_i| {
res_want.decode_vec_i64(col_i, basek, basek * a_size, &mut res_want_vi64);
res_have.decode_vec_i64(col_i, basek, basek * a_size, &mut res_have_vi64);
assert_eq!(res_have_vi64, res_want_vi64);
#[test]
fn vmp_apply_digits() {
let log_n: i32 = 4;
let n: usize = 1 << log_n;
let module: Module<FFT64> = Module::<FFT64>::new(n);
let basek: usize = 8;
let a_size: usize = 6;
let mat_size: usize = 6;
let res_size: usize = a_size;
[1, 2].iter().for_each(|in_cols| {
[1, 2].iter().for_each(|out_cols| {
[1, 3, 6].iter().for_each(|digits| {
let mut source: Source = Source::new([0u8; 32]);
let a_cols: usize = *in_cols;
let res_cols: usize = *out_cols;
let mat_rows: usize = a_size;
let mat_cols_in: usize = a_cols;
let mat_cols_out: usize = res_cols;
let mut scratch: ScratchOwned = ScratchOwned::new(
module.vmp_apply_tmp_bytes(
res_size,
a_size,
mat_rows,
mat_cols_in,
mat_cols_out,
mat_size,
) | module.vec_znx_big_normalize_tmp_bytes(),
);
let mut a: VecZnx<Vec<u8>> = module.new_vec_znx(a_cols, a_size);
(0..a_cols).for_each(|col_i| {
a.fill_uniform(basek, col_i, a.size(), &mut source);
});
let mut mat_znx_dft: MatZnxDft<Vec<u8>, FFT64> =
module.new_mat_znx_dft(mat_rows, mat_cols_in, mat_cols_out, mat_size);
let mut c_dft: VecZnxDft<Vec<u8>, FFT64> = module.new_vec_znx_dft(mat_cols_out, mat_size);
let mut c_big: VecZnxBig<Vec<u8>, FFT64> = module.new_vec_znx_big(mat_cols_out, mat_size);
let mut tmp: VecZnx<Vec<u8>> = module.new_vec_znx(mat_cols_out, mat_size);
let rows: usize = a.size() / digits;
let shift: usize = 1;
// Constructs a [MatZnxDft] that performs cyclic rotations on each submatrix.
(0..rows).for_each(|row_i| {
(0..mat_cols_in).for_each(|col_in_i| {
(0..mat_cols_out).for_each(|col_out_i| {
let idx: usize = shift + col_in_i * mat_cols_out + col_out_i;
let limb: usize = (digits - 1) + row_i * digits;
tmp.at_mut(col_out_i, limb)[idx] = 1 as i64; // X^{idx}
module.vec_znx_dft(1, 0, &mut c_dft, col_out_i, &tmp, col_out_i);
tmp.at_mut(col_out_i, limb)[idx] = 0 as i64;
});
module.vmp_prepare_row(&mut mat_znx_dft, row_i, col_in_i, &c_dft);
});
});
let mut a_dft: VecZnxDft<Vec<u8>, FFT64> = module.new_vec_znx_dft(a_cols, (a_size + digits - 1) / digits);
(0..*digits).for_each(|di| {
(0..a_cols).for_each(|col_i| {
module.vec_znx_dft(*digits, digits - 1 - di, &mut a_dft, col_i, &a, col_i);
});
if di == 0 {
module.vmp_apply(&mut c_dft, &a_dft, &mat_znx_dft, scratch.borrow());
} else {
module.vmp_apply_add(&mut c_dft, &a_dft, &mat_znx_dft, di, scratch.borrow());
}
});
let mut res_have: VecZnx<Vec<u8>> = module.new_vec_znx(res_cols, mat_size);
(0..mat_cols_out).for_each(|i| {
module.vec_znx_idft_tmp_a(&mut c_big, i, &mut c_dft, i);
module.vec_znx_big_normalize(basek, &mut res_have, i, &c_big, i, scratch.borrow());
});
let mut res_want: VecZnx<Vec<u8>> = module.new_vec_znx(res_cols, mat_size);
let mut tmp: VecZnx<Vec<u8>> = module.new_vec_znx(res_cols, mat_size);
(0..res_cols).for_each(|col_i| {
(0..a_cols).for_each(|j| {
module.vec_znx_rotate(
(col_i + j * mat_cols_out + shift) as i64,
&mut tmp,
0,
&a,
j,
);
module.vec_znx_add_inplace(&mut res_want, col_i, &tmp, 0);
});
module.vec_znx_normalize_inplace(basek, &mut res_want, col_i, scratch.borrow());
});
assert_eq!(res_have, res_want)
});
});
});
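To make the digit decomposition above concrete: with a_size = 6 and digits = 3, calling vec_znx_dft with step = digits and offset = digits - 1 - di selects limbs 2 and 5 of `a` for di = 0, limbs 1 and 4 for di = 1, and limbs 0 and 3 for di = 2. The three passes therefore consume every limb of `a` exactly once, and each partial product is accumulated into `c_dft` with shift `di` via `vmp_apply`/`vmp_apply_add`.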

View File

@@ -23,6 +23,7 @@ use std::{cmp::min, fmt};
/// Given 3 polynomials (a, b, c) of Zn\[X\], each decomposed into 4 small polynomials, the memory
/// layout is: `[a0, b0, c0, a1, b1, c1, a2, b2, c2, a3, b3, c3]`, where ai, bi, ci
/// are small polynomials of Zn\[X\].
#[derive(PartialEq, Eq)]
pub struct VecZnx<D> {
pub data: D,
pub n: usize,
@@ -30,6 +31,15 @@ pub struct VecZnx<D> {
pub size: usize,
}
impl<D> fmt::Debug for VecZnx<D>
where
D: AsRef<[u8]>,
{
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
write!(f, "{}", self)
}
}
impl<D> ZnxInfos for VecZnx<D> {
fn cols(&self) -> usize {
self.cols
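The doc comment at the top of this file describes a limb-major layout: all columns of limb 0 first, then all columns of limb 1, and so on. A minimal sketch of the index arithmetic that layout implies, assuming each small polynomial stores its n coefficients contiguously (the helper name `coeff_start` is hypothetical, not part of the crate):

// Start index of the coefficients of column `col` at limb `j` for the layout
// `[a0, b0, c0, a1, b1, c1, ...]`, where each small polynomial holds `n` coefficients.
fn coeff_start(n: usize, cols: usize, col: usize, j: usize) -> usize {
    (j * cols + col) * n
}

// Example for the 3-column, 4-limb layout above: column b (index 1) at limb 2
// starts at coeff_start(n, 3, 1, 2) = 7 * n.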

View File

@@ -3,7 +3,7 @@ use std::marker::PhantomData;
use crate::ffi::vec_znx_dft;
use crate::znx_base::ZnxInfos;
use crate::{
Backend, DataView, DataViewMut, FFT64, Module, VecZnxBig, ZnxSliceSize, ZnxView, ZnxViewMut, ZnxZero, alloc_aligned,
Backend, DataView, DataViewMut, FFT64, Module, VecZnxBig, ZnxSliceSize, ZnxView, alloc_aligned,
};
use std::fmt;
@@ -91,39 +91,6 @@ impl<D: From<Vec<u8>>, B: Backend> VecZnxDft<D, B> {
}
}
impl<D: AsMut<[u8]> + AsRef<[u8]>> VecZnxDft<D, FFT64>
where
VecZnxDft<D, FFT64>: VecZnxDftToMut<FFT64>,
{
/// Extracts the a_col-th column of 'a' and stores it in the self_col-th column of [Self].
pub fn extract_column<C: AsRef<[u8]>>(&mut self, self_col: usize, a: &VecZnxDft<C, FFT64>, a_col: usize)
where
VecZnxDft<C, FFT64>: VecZnxDftToRef<FFT64>,
{
#[cfg(debug_assertions)]
{
assert!(self_col < self.cols());
assert!(a_col < a.cols());
}
let min_size: usize = self.size.min(a.size());
let max_size: usize = self.size;
let mut self_mut: VecZnxDft<&mut [u8], FFT64> = self.to_mut();
let a_ref: VecZnxDft<&[u8], FFT64> = a.to_ref();
(0..min_size).for_each(|i: usize| {
self_mut
.at_mut(self_col, i)
.copy_from_slice(a_ref.at(a_col, i));
});
(min_size..max_size).for_each(|i| {
self_mut.zero_at(self_col, i);
});
}
}
pub type VecZnxDftOwned<B> = VecZnxDft<Vec<u8>, B>;
impl<D, B: Backend> VecZnxDft<D, B> {

View File

@@ -53,7 +53,7 @@ pub trait VecZnxDftOps<B: Backend> {
R: VecZnxDftToMut<B>,
A: VecZnxDftToRef<B>;
fn vec_znx_dft_copy<R, A>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize)
fn vec_znx_dft_copy<R, A>(&self, step: usize, offset: usize, res: &mut R, res_col: usize, a: &A, a_col: usize)
where
R: VecZnxDftToMut<B>,
A: VecZnxDftToRef<B>;
@@ -74,9 +74,9 @@ pub trait VecZnxDftOps<B: Backend> {
R: VecZnxBigToMut<B>,
A: VecZnxDftToRef<B>;
fn vec_znx_dft<R, A>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize)
fn vec_znx_dft<R, A>(&self, step: usize, offset: usize, res: &mut R, res_col: usize, a: &A, a_col: usize)
where
R: VecZnxDftToMut<B>,
R: VecZnxDftToMut<FFT64>,
A: VecZnxToRef;
}
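The `step` and `offset` parameters added to `vec_znx_dft` and `vec_znx_dft_copy` select a strided subset of the input limbs. Going by the FFT64 implementation further down in this diff, a small stand-alone model of which input limb feeds each output limb could look like this (the helper `dft_limb_pairs` is illustrative only, not part of the crate):

// Output limb `j` of the result is the DFT of input limb `offset + j * step`;
// output limbs from `min_steps` onward are zeroed, and in-range output limbs
// whose source would fall past the end of the input are left untouched.
fn dft_limb_pairs(res_size: usize, a_size: usize, step: usize, offset: usize) -> Vec<(usize, usize)> {
    let steps = (a_size + step - 1) / step; // ceil(a_size / step)
    let min_steps = res_size.min(steps);
    (0..min_steps)
        .filter_map(|j| {
            let limb = offset + j * step;
            (limb < a_size).then_some((j, limb))
        })
        .collect()
}

With step = 1 and offset = 0 this reduces to the old one-to-one mapping, which is why the existing call sites in this commit simply become `module.vec_znx_dft(1, 0, ...)`.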
@@ -150,7 +150,7 @@ impl VecZnxDftOps<FFT64> for Module<FFT64> {
}
}
fn vec_znx_dft_copy<R, A>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize)
fn vec_znx_dft_copy<R, A>(&self, step: usize, offset: usize, res: &mut R, res_col: usize, a: &A, a_col: usize)
where
R: VecZnxDftToMut<FFT64>,
A: VecZnxDftToRef<FFT64>,
@@ -158,14 +158,18 @@ impl VecZnxDftOps<FFT64> for Module<FFT64> {
let mut res_mut: VecZnxDft<&mut [u8], FFT64> = res.to_mut();
let a_ref: VecZnxDft<&[u8], FFT64> = a.to_ref();
let min_size: usize = min(res_mut.size(), a_ref.size());
let steps: usize = (a_ref.size() + step - 1) / step;
let min_steps: usize = min(res_mut.size(), steps);
(0..min_size).for_each(|j| {
res_mut
(0..min_steps).for_each(|j| {
let limb: usize = offset + j * step;
if limb < a_ref.size() {
res_mut
.at_mut(res_col, j)
.copy_from_slice(a_ref.at(a_col, j));
.copy_from_slice(a_ref.at(a_col, limb));
}
});
(min_size..res_mut.size()).for_each(|j| {
(min_steps..res_mut.size()).for_each(|j| {
res_mut.zero_at(res_col, j);
})
}
@@ -224,32 +228,30 @@ impl VecZnxDftOps<FFT64> for Module<FFT64> {
unsafe { vec_znx_dft::vec_znx_idft_tmp_bytes(self.ptr) as usize }
}
/// b <- DFT(a)
///
/// # Panics
/// If b.cols < a_col
fn vec_znx_dft<R, A>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize)
fn vec_znx_dft<R, A>(&self, step: usize, offset: usize, res: &mut R, res_col: usize, a: &A, a_col: usize)
where
R: VecZnxDftToMut<FFT64>,
A: VecZnxToRef,
{
let mut res_mut: VecZnxDft<&mut [u8], FFT64> = res.to_mut();
let a_ref: crate::VecZnx<&[u8]> = a.to_ref();
let min_size: usize = min(res_mut.size(), a_ref.size());
let steps: usize = (a_ref.size() + step - 1) / step;
let min_steps: usize = min(res_mut.size(), steps);
unsafe {
(0..min_size).for_each(|j| {
vec_znx_dft::vec_znx_dft(
self.ptr,
res_mut.at_mut_ptr(res_col, j) as *mut vec_znx_dft::vec_znx_dft_t,
1 as u64,
a_ref.at_ptr(a_col, j),
1 as u64,
a_ref.sl() as u64,
)
(0..min_steps).for_each(|j| {
let limb: usize = offset + j * step;
if limb < a_ref.size() {
vec_znx_dft::vec_znx_dft(
self.ptr,
res_mut.at_mut_ptr(res_col, j) as *mut vec_znx_dft::vec_znx_dft_t,
1 as u64,
a_ref.at_ptr(a_col, limb),
1 as u64,
a_ref.sl() as u64,
)
}
});
(min_size..res_mut.size()).for_each(|j| {
(min_steps..res_mut.size()).for_each(|j| {
res_mut.zero_at(res_col, j);
});
}