mirror of
https://github.com/arnaucube/poulpy.git
synced 2026-02-10 05:06:44 +01:00
fixed all tests
This commit is contained in:
@@ -40,7 +40,7 @@ fn main() {
|
||||
|
||||
let mut buf_dft: VecZnxDft<Vec<u8>, FFT64> = module.new_vec_znx_dft(1, ct_size);
|
||||
|
||||
module.vec_znx_dft(&mut buf_dft, 0, &ct, 1);
|
||||
module.vec_znx_dft(1, 0, &mut buf_dft, 0, &ct, 1);
|
||||
|
||||
// Applies DFT(ct[1]) * DFT(s)
|
||||
module.svp_apply_inplace(
|
||||
@@ -102,7 +102,7 @@ fn main() {
|
||||
// Decryption
|
||||
|
||||
// DFT(ct[1] * s)
|
||||
module.vec_znx_dft(&mut buf_dft, 0, &ct, 1);
|
||||
module.vec_znx_dft(1, 0, &mut buf_dft, 0, &ct, 1);
|
||||
module.svp_apply_inplace(
|
||||
&mut buf_dft,
|
||||
0, // Selects the first column of res.
|
||||
|
||||
@@ -393,7 +393,7 @@ mod tests {
|
||||
let mut source: Source = Source::new([0u8; 32]);
|
||||
(0..mat_cols_out).for_each(|col_out| {
|
||||
a.fill_uniform(basek, col_out, mat_size, &mut source);
|
||||
module.vec_znx_dft(&mut a_dft, col_out, &a, col_out);
|
||||
module.vec_znx_dft(1, 0, &mut a_dft, col_out, &a, col_out);
|
||||
});
|
||||
module.vmp_prepare_row(&mut mat, row_i, col_in, &a_dft);
|
||||
module.vmp_extract_row(&mut b_dft, &mat, row_i, col_in);
|
||||
@@ -453,7 +453,7 @@ mod tests {
|
||||
(0..mat_cols_out).for_each(|col_out_i| {
|
||||
let idx = 1 + col_in_i * mat_cols_out + col_out_i;
|
||||
tmp.at_mut(col_out_i, row_i)[idx] = 1 as i64; // X^{idx}
|
||||
module.vec_znx_dft(&mut c_dft, col_out_i, &tmp, col_out_i);
|
||||
module.vec_znx_dft(1, 0, &mut c_dft, col_out_i, &tmp, col_out_i);
|
||||
tmp.at_mut(col_out_i, row_i)[idx] = 0 as i64;
|
||||
});
|
||||
module.vmp_prepare_row(&mut mat_znx_dft, row_i, col_in_i, &c_dft);
|
||||
@@ -462,7 +462,7 @@ mod tests {
|
||||
|
||||
let mut a_dft: VecZnxDft<Vec<u8>, FFT64> = module.new_vec_znx_dft(a_cols, a_size);
|
||||
(0..a_cols).for_each(|i| {
|
||||
module.vec_znx_dft(&mut a_dft, i, &a, i);
|
||||
module.vec_znx_dft(1, 0, &mut a_dft, i, &a, i);
|
||||
});
|
||||
|
||||
module.vmp_apply(&mut c_dft, &a_dft, &mat_znx_dft, scratch.borrow());
|
||||
@@ -489,14 +489,15 @@ mod tests {
|
||||
|
||||
#[test]
|
||||
fn vmp_apply_add() {
|
||||
let log_n: i32 = 5;
|
||||
let log_n: i32 = 4;
|
||||
let n: usize = 1 << log_n;
|
||||
|
||||
let module: Module<FFT64> = Module::<FFT64>::new(n);
|
||||
let basek: usize = 8;
|
||||
let a_size: usize = 6;
|
||||
let mat_size: usize = 6;
|
||||
let a_size: usize = 5;
|
||||
let mat_size: usize = 5;
|
||||
let res_size: usize = a_size;
|
||||
let mut source: Source = Source::new([0u8; 32]);
|
||||
|
||||
[1, 2].iter().for_each(|in_cols| {
|
||||
[1, 2].iter().for_each(|out_cols| {
|
||||
@@ -521,10 +522,8 @@ mod tests {
|
||||
|
||||
let mut a: VecZnx<Vec<u8>> = module.new_vec_znx(a_cols, a_size);
|
||||
|
||||
(0..a_cols).for_each(|i| {
|
||||
(0..a_size).for_each(|j| {
|
||||
a.at_mut(i, j)[i + 1] = 1 + j as i64;
|
||||
});
|
||||
(0..a_cols).for_each(|col_i| {
|
||||
a.fill_uniform(basek, col_i, a.size(), &mut source);
|
||||
});
|
||||
|
||||
let mut mat_znx_dft: MatZnxDft<Vec<u8>, FFT64> =
|
||||
@@ -539,9 +538,9 @@ mod tests {
|
||||
(0..a.size()).for_each(|row_i| {
|
||||
(0..mat_cols_in).for_each(|col_in_i| {
|
||||
(0..mat_cols_out).for_each(|col_out_i| {
|
||||
let idx = 1 + col_in_i * mat_cols_out + col_out_i;
|
||||
let idx: usize = 1 + col_in_i * mat_cols_out + col_out_i;
|
||||
tmp.at_mut(col_out_i, row_i)[idx] = 1 as i64; // X^{idx}
|
||||
module.vec_znx_dft(&mut c_dft, col_out_i, &tmp, col_out_i);
|
||||
module.vec_znx_dft(1, 0, &mut c_dft, col_out_i, &tmp, col_out_i);
|
||||
tmp.at_mut(col_out_i, row_i)[idx] = 0 as i64;
|
||||
});
|
||||
module.vmp_prepare_row(&mut mat_znx_dft, row_i, col_in_i, &c_dft);
|
||||
@@ -550,12 +549,12 @@ mod tests {
|
||||
|
||||
let mut a_dft: VecZnxDft<Vec<u8>, FFT64> = module.new_vec_znx_dft(a_cols, a_size);
|
||||
(0..a_cols).for_each(|i| {
|
||||
module.vec_znx_dft(&mut a_dft, i, &a, i);
|
||||
module.vec_znx_dft(1, 0, &mut a_dft, i, &a, i);
|
||||
});
|
||||
|
||||
c_dft.zero();
|
||||
(0..c_dft.cols()).for_each(|i| {
|
||||
module.vec_znx_dft(&mut c_dft, i, &a, 0);
|
||||
module.vec_znx_dft(1, 0, &mut c_dft, i, &a, 0);
|
||||
});
|
||||
|
||||
module.vmp_apply_add(&mut c_dft, &a_dft, &mat_znx_dft, shift, scratch.borrow());
|
||||
@@ -582,16 +581,118 @@ mod tests {
|
||||
);
|
||||
(0..res_cols).for_each(|i| {
|
||||
module.vec_znx_add_inplace(&mut res_want, i, &a, 0);
|
||||
module.vec_znx_normalize_inplace(basek, &mut res_want, i, scratch.borrow());
|
||||
});
|
||||
|
||||
let mut res_have_vi64: Vec<i64> = vec![i64::default(); n];
|
||||
let mut res_want_vi64: Vec<i64> = vec![i64::default(); n];
|
||||
assert_eq!(res_want, res_have);
|
||||
});
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
(0..mat_cols_out).for_each(|col_i| {
|
||||
res_want.decode_vec_i64(col_i, basek, basek * a_size, &mut res_want_vi64);
|
||||
res_have.decode_vec_i64(col_i, basek, basek * a_size, &mut res_have_vi64);
|
||||
assert_eq!(res_have_vi64, res_want_vi64);
|
||||
#[test]
|
||||
fn vmp_apply_digits() {
|
||||
let log_n: i32 = 4;
|
||||
let n: usize = 1 << log_n;
|
||||
|
||||
let module: Module<FFT64> = Module::<FFT64>::new(n);
|
||||
let basek: usize = 8;
|
||||
let a_size: usize = 6;
|
||||
let mat_size: usize = 6;
|
||||
let res_size: usize = a_size;
|
||||
|
||||
[1, 2].iter().for_each(|in_cols| {
|
||||
[1, 2].iter().for_each(|out_cols| {
|
||||
[1, 3, 6].iter().for_each(|digits| {
|
||||
let mut source: Source = Source::new([0u8; 32]);
|
||||
|
||||
let a_cols: usize = *in_cols;
|
||||
let res_cols: usize = *out_cols;
|
||||
|
||||
let mat_rows: usize = a_size;
|
||||
let mat_cols_in: usize = a_cols;
|
||||
let mat_cols_out: usize = res_cols;
|
||||
|
||||
let mut scratch: ScratchOwned = ScratchOwned::new(
|
||||
module.vmp_apply_tmp_bytes(
|
||||
res_size,
|
||||
a_size,
|
||||
mat_rows,
|
||||
mat_cols_in,
|
||||
mat_cols_out,
|
||||
mat_size,
|
||||
) | module.vec_znx_big_normalize_tmp_bytes(),
|
||||
);
|
||||
|
||||
let mut a: VecZnx<Vec<u8>> = module.new_vec_znx(a_cols, a_size);
|
||||
|
||||
(0..a_cols).for_each(|col_i| {
|
||||
a.fill_uniform(basek, col_i, a.size(), &mut source);
|
||||
});
|
||||
|
||||
let mut mat_znx_dft: MatZnxDft<Vec<u8>, FFT64> =
|
||||
module.new_mat_znx_dft(mat_rows, mat_cols_in, mat_cols_out, mat_size);
|
||||
|
||||
let mut c_dft: VecZnxDft<Vec<u8>, FFT64> = module.new_vec_znx_dft(mat_cols_out, mat_size);
|
||||
let mut c_big: VecZnxBig<Vec<u8>, FFT64> = module.new_vec_znx_big(mat_cols_out, mat_size);
|
||||
|
||||
let mut tmp: VecZnx<Vec<u8>> = module.new_vec_znx(mat_cols_out, mat_size);
|
||||
|
||||
let rows: usize = a.size() / digits;
|
||||
|
||||
let shift: usize = 1;
|
||||
|
||||
// Construts a [VecZnxMatDft] that performs cyclic rotations on each submatrix.
|
||||
(0..rows).for_each(|row_i| {
|
||||
(0..mat_cols_in).for_each(|col_in_i| {
|
||||
(0..mat_cols_out).for_each(|col_out_i| {
|
||||
let idx: usize = shift + col_in_i * mat_cols_out + col_out_i;
|
||||
let limb: usize = (digits - 1) + row_i * digits;
|
||||
tmp.at_mut(col_out_i, limb)[idx] = 1 as i64; // X^{idx}
|
||||
module.vec_znx_dft(1, 0, &mut c_dft, col_out_i, &tmp, col_out_i);
|
||||
tmp.at_mut(col_out_i, limb)[idx] = 0 as i64;
|
||||
});
|
||||
module.vmp_prepare_row(&mut mat_znx_dft, row_i, col_in_i, &c_dft);
|
||||
});
|
||||
});
|
||||
|
||||
let mut a_dft: VecZnxDft<Vec<u8>, FFT64> = module.new_vec_znx_dft(a_cols, (a_size + digits - 1) / digits);
|
||||
|
||||
(0..*digits).for_each(|di| {
|
||||
(0..a_cols).for_each(|col_i| {
|
||||
module.vec_znx_dft(digits - 1 - di, *digits, &mut a_dft, col_i, &a, col_i);
|
||||
});
|
||||
|
||||
if di == 0 {
|
||||
module.vmp_apply(&mut c_dft, &a_dft, &mat_znx_dft, scratch.borrow());
|
||||
} else {
|
||||
module.vmp_apply_add(&mut c_dft, &a_dft, &mat_znx_dft, di, scratch.borrow());
|
||||
}
|
||||
});
|
||||
|
||||
let mut res_have: VecZnx<Vec<u8>> = module.new_vec_znx(res_cols, mat_size);
|
||||
(0..mat_cols_out).for_each(|i| {
|
||||
module.vec_znx_idft_tmp_a(&mut c_big, i, &mut c_dft, i);
|
||||
module.vec_znx_big_normalize(basek, &mut res_have, i, &c_big, i, scratch.borrow());
|
||||
});
|
||||
|
||||
let mut res_want: VecZnx<Vec<u8>> = module.new_vec_znx(res_cols, mat_size);
|
||||
let mut tmp: VecZnx<Vec<u8>> = module.new_vec_znx(res_cols, mat_size);
|
||||
(0..res_cols).for_each(|col_i| {
|
||||
(0..a_cols).for_each(|j| {
|
||||
module.vec_znx_rotate(
|
||||
(col_i + j * mat_cols_out + shift) as i64,
|
||||
&mut tmp,
|
||||
0,
|
||||
&a,
|
||||
j,
|
||||
);
|
||||
module.vec_znx_add_inplace(&mut res_want, col_i, &tmp, 0);
|
||||
});
|
||||
module.vec_znx_normalize_inplace(basek, &mut res_want, col_i, scratch.borrow());
|
||||
});
|
||||
|
||||
assert_eq!(res_have, res_want)
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
@@ -23,6 +23,7 @@ use std::{cmp::min, fmt};
|
||||
/// Given 3 polynomials (a, b, c) of Zn\[X\], each with 4 columns, then the memory
|
||||
/// layout is: `[a0, b0, c0, a1, b1, c1, a2, b2, c2, a3, b3, c3]`, where ai, bi, ci
|
||||
/// are small polynomials of Zn\[X\].
|
||||
#[derive(PartialEq, Eq)]
|
||||
pub struct VecZnx<D> {
|
||||
pub data: D,
|
||||
pub n: usize,
|
||||
@@ -30,6 +31,15 @@ pub struct VecZnx<D> {
|
||||
pub size: usize,
|
||||
}
|
||||
|
||||
impl<D> fmt::Debug for VecZnx<D>
|
||||
where
|
||||
D: AsRef<[u8]>,
|
||||
{
|
||||
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
|
||||
write!(f, "{}", self)
|
||||
}
|
||||
}
|
||||
|
||||
impl<D> ZnxInfos for VecZnx<D> {
|
||||
fn cols(&self) -> usize {
|
||||
self.cols
|
||||
|
||||
@@ -3,7 +3,7 @@ use std::marker::PhantomData;
|
||||
use crate::ffi::vec_znx_dft;
|
||||
use crate::znx_base::ZnxInfos;
|
||||
use crate::{
|
||||
Backend, DataView, DataViewMut, FFT64, Module, VecZnxBig, ZnxSliceSize, ZnxView, ZnxViewMut, ZnxZero, alloc_aligned,
|
||||
Backend, DataView, DataViewMut, FFT64, Module, VecZnxBig, ZnxSliceSize, ZnxView, alloc_aligned,
|
||||
};
|
||||
use std::fmt;
|
||||
|
||||
@@ -91,39 +91,6 @@ impl<D: From<Vec<u8>>, B: Backend> VecZnxDft<D, B> {
|
||||
}
|
||||
}
|
||||
|
||||
impl<D: AsMut<[u8]> + AsRef<[u8]>> VecZnxDft<D, FFT64>
|
||||
where
|
||||
VecZnxDft<D, FFT64>: VecZnxDftToMut<FFT64>,
|
||||
{
|
||||
/// Extracts the a_col-th column of 'a' and stores it on the self_col-th column [Self].
|
||||
pub fn extract_column<C: AsRef<[u8]>>(&mut self, self_col: usize, a: &VecZnxDft<C, FFT64>, a_col: usize)
|
||||
where
|
||||
VecZnxDft<C, FFT64>: VecZnxDftToRef<FFT64>,
|
||||
{
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
assert!(self_col < self.cols());
|
||||
assert!(a_col < a.cols());
|
||||
}
|
||||
|
||||
let min_size: usize = self.size.min(a.size());
|
||||
let max_size: usize = self.size;
|
||||
|
||||
let mut self_mut: VecZnxDft<&mut [u8], FFT64> = self.to_mut();
|
||||
let a_ref: VecZnxDft<&[u8], FFT64> = a.to_ref();
|
||||
|
||||
(0..min_size).for_each(|i: usize| {
|
||||
self_mut
|
||||
.at_mut(self_col, i)
|
||||
.copy_from_slice(a_ref.at(a_col, i));
|
||||
});
|
||||
|
||||
(min_size..max_size).for_each(|i| {
|
||||
self_mut.zero_at(self_col, i);
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
pub type VecZnxDftOwned<B> = VecZnxDft<Vec<u8>, B>;
|
||||
|
||||
impl<D, B: Backend> VecZnxDft<D, B> {
|
||||
|
||||
@@ -53,7 +53,7 @@ pub trait VecZnxDftOps<B: Backend> {
|
||||
R: VecZnxDftToMut<B>,
|
||||
A: VecZnxDftToRef<B>;
|
||||
|
||||
fn vec_znx_dft_copy<R, A>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize)
|
||||
fn vec_znx_dft_copy<R, A>(&self, step: usize, offset: usize, res: &mut R, res_col: usize, a: &A, a_col: usize)
|
||||
where
|
||||
R: VecZnxDftToMut<B>,
|
||||
A: VecZnxDftToRef<B>;
|
||||
@@ -74,9 +74,9 @@ pub trait VecZnxDftOps<B: Backend> {
|
||||
R: VecZnxBigToMut<B>,
|
||||
A: VecZnxDftToRef<B>;
|
||||
|
||||
fn vec_znx_dft<R, A>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize)
|
||||
fn vec_znx_dft<R, A>(&self, step: usize, offset: usize, res: &mut R, res_col: usize, a: &A, a_col: usize)
|
||||
where
|
||||
R: VecZnxDftToMut<B>,
|
||||
R: VecZnxDftToMut<FFT64>,
|
||||
A: VecZnxToRef;
|
||||
}
|
||||
|
||||
@@ -150,7 +150,7 @@ impl VecZnxDftOps<FFT64> for Module<FFT64> {
|
||||
}
|
||||
}
|
||||
|
||||
fn vec_znx_dft_copy<R, A>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize)
|
||||
fn vec_znx_dft_copy<R, A>(&self, step: usize, offset: usize, res: &mut R, res_col: usize, a: &A, a_col: usize)
|
||||
where
|
||||
R: VecZnxDftToMut<FFT64>,
|
||||
A: VecZnxDftToRef<FFT64>,
|
||||
@@ -158,14 +158,18 @@ impl VecZnxDftOps<FFT64> for Module<FFT64> {
|
||||
let mut res_mut: VecZnxDft<&mut [u8], FFT64> = res.to_mut();
|
||||
let a_ref: VecZnxDft<&[u8], FFT64> = a.to_ref();
|
||||
|
||||
let min_size: usize = min(res_mut.size(), a_ref.size());
|
||||
let steps: usize = (a_ref.size() + step - 1) / step;
|
||||
let min_steps: usize = min(res_mut.size(), steps);
|
||||
|
||||
(0..min_size).for_each(|j| {
|
||||
res_mut
|
||||
(0..min_steps).for_each(|j| {
|
||||
let limb: usize = offset + j * step;
|
||||
if limb < a_ref.size(){
|
||||
res_mut
|
||||
.at_mut(res_col, j)
|
||||
.copy_from_slice(a_ref.at(a_col, j));
|
||||
.copy_from_slice(a_ref.at(a_col, limb));
|
||||
}
|
||||
});
|
||||
(min_size..res_mut.size()).for_each(|j| {
|
||||
(min_steps..res_mut.size()).for_each(|j| {
|
||||
res_mut.zero_at(res_col, j);
|
||||
})
|
||||
}
|
||||
@@ -224,32 +228,30 @@ impl VecZnxDftOps<FFT64> for Module<FFT64> {
|
||||
unsafe { vec_znx_dft::vec_znx_idft_tmp_bytes(self.ptr) as usize }
|
||||
}
|
||||
|
||||
/// b <- DFT(a)
|
||||
///
|
||||
/// # Panics
|
||||
/// If b.cols < a_col
|
||||
fn vec_znx_dft<R, A>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize)
|
||||
fn vec_znx_dft<R, A>(&self, step: usize, offset: usize, res: &mut R, res_col: usize, a: &A, a_col: usize)
|
||||
where
|
||||
R: VecZnxDftToMut<FFT64>,
|
||||
A: VecZnxToRef,
|
||||
{
|
||||
let mut res_mut: VecZnxDft<&mut [u8], FFT64> = res.to_mut();
|
||||
let a_ref: crate::VecZnx<&[u8]> = a.to_ref();
|
||||
|
||||
let min_size: usize = min(res_mut.size(), a_ref.size());
|
||||
|
||||
let steps: usize = (a_ref.size() + step - 1) / step;
|
||||
let min_steps: usize = min(res_mut.size(), steps);
|
||||
unsafe {
|
||||
(0..min_size).for_each(|j| {
|
||||
vec_znx_dft::vec_znx_dft(
|
||||
self.ptr,
|
||||
res_mut.at_mut_ptr(res_col, j) as *mut vec_znx_dft::vec_znx_dft_t,
|
||||
1 as u64,
|
||||
a_ref.at_ptr(a_col, j),
|
||||
1 as u64,
|
||||
a_ref.sl() as u64,
|
||||
)
|
||||
(0..min_steps).for_each(|j| {
|
||||
let limb: usize = offset + j * step;
|
||||
if limb < a_ref.size() {
|
||||
vec_znx_dft::vec_znx_dft(
|
||||
self.ptr,
|
||||
res_mut.at_mut_ptr(res_col, j) as *mut vec_znx_dft::vec_znx_dft_t,
|
||||
1 as u64,
|
||||
a_ref.at_ptr(a_col, limb),
|
||||
1 as u64,
|
||||
a_ref.sl() as u64,
|
||||
)
|
||||
}
|
||||
});
|
||||
(min_size..res_mut.size()).for_each(|j| {
|
||||
(min_steps..res_mut.size()).for_each(|j| {
|
||||
res_mut.zero_at(res_col, j);
|
||||
});
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user