mirror of
https://github.com/arnaucube/poulpy.git
synced 2026-02-10 05:06:44 +01:00
wip on BR + added enc/dec for LWE
This commit is contained in:
@@ -103,7 +103,7 @@ fn alloc_aligned_custom_u8(size: usize, align: usize) -> Vec<u8> {
|
||||
/// Size of T * size msut be a multiple of [DEFAULTALIGN].
|
||||
pub fn alloc_aligned_custom<T>(size: usize, align: usize) -> Vec<T> {
|
||||
assert_eq!(
|
||||
(size * size_of::<T>()) % align,
|
||||
(size * size_of::<T>()) % (align/ size_of::<T>()),
|
||||
0,
|
||||
"size={} must be a multiple of align={}",
|
||||
size,
|
||||
@@ -121,7 +121,7 @@ pub fn alloc_aligned_custom<T>(size: usize, align: usize) -> Vec<T> {
|
||||
/// of [DEFAULTALIGN]/size_of::<T>() that is equal or greater to `size`.
|
||||
pub fn alloc_aligned<T>(size: usize) -> Vec<T> {
|
||||
alloc_aligned_custom::<T>(
|
||||
size + (size % (DEFAULTALIGN / size_of::<T>())),
|
||||
size + (DEFAULTALIGN - (size % (DEFAULTALIGN / size_of::<T>()))),
|
||||
DEFAULTALIGN,
|
||||
)
|
||||
}
|
||||
|
||||
@@ -82,6 +82,12 @@ pub trait MatZnxDftOps<BACKEND: Backend> {
|
||||
where
|
||||
A: MatZnxToMut<FFT64>;
|
||||
|
||||
/// Multiplies A by (X^{k} - 1).
|
||||
fn mat_znx_dft_mul_x_pow_minus_one_add_inplace<R, A>(&self, k: i64, res: &mut R, a: &A, scratch: &mut Scratch)
|
||||
where
|
||||
R: MatZnxToMut<FFT64>,
|
||||
A: MatZnxToRef<FFT64>;
|
||||
|
||||
/// Applies the vector matrix product [VecZnxDft] x [MatZnxDft].
|
||||
/// The size of `buf` is given by [MatZnxDftOps::vmp_apply_dft_to_dft_tmp_bytes].
|
||||
///
|
||||
@@ -212,7 +218,7 @@ impl MatZnxDftOps<FFT64> for Module<FFT64> {
|
||||
|
||||
self.mat_znx_dft_set_row(&mut res, row_i, col_j, &tmp_1);
|
||||
});
|
||||
})
|
||||
});
|
||||
}
|
||||
|
||||
fn mat_znx_dft_mul_x_pow_minus_one_inplace<A>(&self, k: i64, a: &mut A, scratch: &mut Scratch)
|
||||
@@ -249,7 +255,52 @@ impl MatZnxDftOps<FFT64> for Module<FFT64> {
|
||||
|
||||
self.mat_znx_dft_set_row(&mut a, row_i, col_j, &tmp_1);
|
||||
});
|
||||
})
|
||||
});
|
||||
}
|
||||
|
||||
fn mat_znx_dft_mul_x_pow_minus_one_add_inplace<R, A>(&self, k: i64, res: &mut R, a: &A, scratch: &mut Scratch)
|
||||
where
|
||||
R: MatZnxToMut<FFT64>,
|
||||
A: MatZnxToRef<FFT64>,
|
||||
{
|
||||
let mut res: MatZnxDft<&mut [u8], FFT64> = res.to_mut();
|
||||
let a: MatZnxDft<&[u8], FFT64> = a.to_ref();
|
||||
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
assert_eq!(a.n(), self.n());
|
||||
}
|
||||
|
||||
let (mut xpm1_dft, scratch1) = scratch.tmp_scalar_znx_dft(self, 1);
|
||||
|
||||
{
|
||||
let (mut xpm1, _) = scratch1.tmp_scalar_znx(self, 1);
|
||||
xpm1.data[0] = 1;
|
||||
self.vec_znx_rotate_inplace(k, &mut xpm1, 0);
|
||||
self.svp_prepare(&mut xpm1_dft, 0, &xpm1, 0);
|
||||
}
|
||||
|
||||
let (mut tmp_0, scratch2) = scratch1.tmp_vec_znx_dft(self, a.cols_out(), a.size());
|
||||
let (mut tmp_1, _) = scratch2.tmp_vec_znx_dft(self, a.cols_out(), a.size());
|
||||
|
||||
(0..a.rows()).for_each(|row_i| {
|
||||
(0..a.cols_in()).for_each(|col_j| {
|
||||
self.mat_znx_dft_get_row(&mut tmp_0, &a, row_i, col_j);
|
||||
|
||||
(0..tmp_0.cols()).for_each(|i| {
|
||||
self.svp_apply(&mut tmp_1, i, &xpm1_dft, 0, &tmp_0, i);
|
||||
self.vec_znx_dft_sub_ab_inplace(&mut tmp_1, i, &tmp_0, i);
|
||||
});
|
||||
|
||||
self.mat_znx_dft_get_row(&mut tmp_0, &res, row_i, col_j);
|
||||
|
||||
(0..tmp_0.cols()).for_each(|i| {
|
||||
self.vec_znx_dft_add_inplace(&mut tmp_0, i, &tmp_1, i);
|
||||
});
|
||||
|
||||
self.mat_znx_dft_set_row(&mut res, row_i, col_j, &tmp_0);
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
fn mat_znx_dft_set_row<R, A>(&self, res: &mut R, res_row: usize, res_col_in: usize, a: &A)
|
||||
@@ -845,7 +896,6 @@ mod tests {
|
||||
|
||||
(0..cols_out).for_each(|j| {
|
||||
module.vec_znx_idft(&mut tmp_big, 0, &tmp_dft, j, scratch.borrow());
|
||||
// module.vec_znx_big_normalize(basek, &mut want, j, &tmp_big, 0, scratch.borrow());
|
||||
module.vec_znx_big_normalize(basek, &mut tmp, 0, &tmp_big, 0, scratch.borrow());
|
||||
module.vec_znx_rotate(k, &mut want, j, &tmp, 0);
|
||||
module.vec_znx_sub_ab_inplace(&mut want, j, &tmp, 0);
|
||||
@@ -863,4 +913,84 @@ mod tests {
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn mat_znx_dft_mul_x_pow_minus_one_add_inplace() {
|
||||
let log_n: i32 = 5;
|
||||
let n: usize = 1 << log_n;
|
||||
|
||||
let module: Module<FFT64> = Module::<FFT64>::new(n);
|
||||
let basek: usize = 8;
|
||||
let rows: usize = 2;
|
||||
let cols_in: usize = 2;
|
||||
let cols_out: usize = 2;
|
||||
let size: usize = 4;
|
||||
|
||||
let mut scratch: ScratchOwned = ScratchOwned::new(module.mat_znx_dft_mul_x_pow_minus_one_scratch_space(size, cols_out));
|
||||
|
||||
let mut mat_want: MatZnxDft<Vec<u8>, FFT64> = module.new_mat_znx_dft(rows, cols_in, cols_out, size);
|
||||
let mut mat_have: MatZnxDft<Vec<u8>, FFT64> = module.new_mat_znx_dft(rows, cols_in, cols_out, size);
|
||||
|
||||
let mut tmp: VecZnx<Vec<u8>> = module.new_vec_znx(1, size);
|
||||
let mut tmp_dft: VecZnxDft<Vec<u8>, FFT64> = module.new_vec_znx_dft(cols_out, size);
|
||||
|
||||
let mut source: Source = Source::new([0u8; 32]);
|
||||
|
||||
(0..mat_have.rows()).for_each(|row_i| {
|
||||
(0..mat_have.cols_in()).for_each(|col_i| {
|
||||
(0..cols_out).for_each(|j| {
|
||||
tmp.fill_uniform(basek, 0, size, &mut source);
|
||||
module.vec_znx_dft(1, 0, &mut tmp_dft, j, &tmp, 0);
|
||||
});
|
||||
|
||||
module.mat_znx_dft_set_row(&mut mat_have, row_i, col_i, &tmp_dft);
|
||||
});
|
||||
});
|
||||
|
||||
(0..mat_want.rows()).for_each(|row_i| {
|
||||
(0..mat_want.cols_in()).for_each(|col_i| {
|
||||
(0..cols_out).for_each(|j| {
|
||||
tmp.fill_uniform(basek, 0, size, &mut source);
|
||||
module.vec_znx_dft(1, 0, &mut tmp_dft, j, &tmp, 0);
|
||||
});
|
||||
|
||||
module.mat_znx_dft_set_row(&mut mat_want, row_i, col_i, &tmp_dft);
|
||||
});
|
||||
});
|
||||
|
||||
let k: i64 = 1;
|
||||
|
||||
module.mat_znx_dft_mul_x_pow_minus_one_add_inplace(k, &mut mat_have, &mat_want, scratch.borrow());
|
||||
|
||||
let mut have: VecZnx<Vec<u8>> = module.new_vec_znx(cols_out, size);
|
||||
let mut want: VecZnx<Vec<u8>> = module.new_vec_znx(cols_out, size);
|
||||
let mut tmp_big: VecZnxBig<Vec<u8>, FFT64> = module.new_vec_znx_big(1, size);
|
||||
|
||||
let mut source: Source = Source::new([0u8; 32]);
|
||||
(0..mat_want.rows()).for_each(|row_i| {
|
||||
(0..mat_want.cols_in()).for_each(|col_i| {
|
||||
module.mat_znx_dft_get_row(&mut tmp_dft, &mat_want, row_i, col_i);
|
||||
|
||||
(0..cols_out).for_each(|j| {
|
||||
module.vec_znx_idft(&mut tmp_big, 0, &tmp_dft, j, scratch.borrow());
|
||||
module.vec_znx_big_normalize(basek, &mut tmp, 0, &tmp_big, 0, scratch.borrow());
|
||||
module.vec_znx_rotate(k, &mut want, j, &tmp, 0);
|
||||
module.vec_znx_sub_ab_inplace(&mut want, j, &tmp, 0);
|
||||
|
||||
tmp.fill_uniform(basek, 0, size, &mut source);
|
||||
module.vec_znx_add_inplace(&mut want, j, &tmp, 0);
|
||||
module.vec_znx_normalize_inplace(basek, &mut want, j, scratch.borrow());
|
||||
});
|
||||
|
||||
module.mat_znx_dft_get_row(&mut tmp_dft, &mat_have, row_i, col_i);
|
||||
|
||||
(0..cols_out).for_each(|j| {
|
||||
module.vec_znx_idft(&mut tmp_big, 0, &tmp_dft, j, scratch.borrow());
|
||||
module.vec_znx_big_normalize(basek, &mut have, j, &tmp_big, 0, scratch.borrow());
|
||||
});
|
||||
|
||||
assert_eq!(have, want)
|
||||
});
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
@@ -91,9 +91,9 @@ impl<D: AsMut<[u8]> + AsRef<[u8]>> ScalarZnx<D> {
|
||||
}
|
||||
|
||||
pub fn fill_binary_block(&mut self, col: usize, block_size: usize, source: &mut Source) {
|
||||
assert!(block_size & (block_size - 1) == 0);
|
||||
assert!(self.n() % block_size == 0);
|
||||
let max_idx: u64 = (block_size + 1) as u64;
|
||||
let mask_idx: u64 = (2 * block_size - 1) as u64;
|
||||
let mask_idx: u64 = (1<<((u64::BITS - max_idx.leading_zeros())as u64)) - 1 ;
|
||||
for block in self.at_mut(col, 0).chunks_mut(block_size) {
|
||||
let idx: usize = source.next_u64n(max_idx, mask_idx) as usize;
|
||||
if idx != block_size {
|
||||
|
||||
@@ -177,7 +177,7 @@ impl<D: From<Vec<u8>>> VecZnx<D> {
|
||||
n * cols * size * size_of::<Scalar>()
|
||||
}
|
||||
|
||||
pub(crate) fn new<Scalar: Sized>(n: usize, cols: usize, size: usize) -> Self {
|
||||
pub fn new<Scalar: Sized>(n: usize, cols: usize, size: usize) -> Self {
|
||||
let data = alloc_aligned::<u8>(Self::bytes_of::<Scalar>(n, cols, size));
|
||||
Self {
|
||||
data: data.into(),
|
||||
@@ -243,7 +243,13 @@ fn normalize_tmp_bytes(n: usize) -> usize {
|
||||
n * std::mem::size_of::<i64>()
|
||||
}
|
||||
|
||||
#[allow(dead_code)]
|
||||
impl<D: AsRef<[u8]> + AsMut<[u8]>> VecZnx<D>{
|
||||
pub fn normalize(&mut self, basek: usize, a_col: usize, tmp_bytes: &mut [u8]){
|
||||
normalize(basek, self, a_col, tmp_bytes);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
fn normalize<D: AsMut<[u8]> + AsRef<[u8]>>(basek: usize, a: &mut VecZnx<D>, a_col: usize, tmp_bytes: &mut [u8]) {
|
||||
let n: usize = a.n();
|
||||
|
||||
|
||||
Reference in New Issue
Block a user