wip on BR + added enc/dec for LWE

This commit is contained in:
Jean-Philippe Bossuat
2025-06-13 20:45:24 +02:00
parent e8cfb5e2ab
commit 829b8be610
43 changed files with 745 additions and 688 deletions

View File

@@ -103,7 +103,7 @@ fn alloc_aligned_custom_u8(size: usize, align: usize) -> Vec<u8> {
/// Size of T * size must be a multiple of [DEFAULTALIGN].
pub fn alloc_aligned_custom<T>(size: usize, align: usize) -> Vec<T> {
assert_eq!(
(size * size_of::<T>()) % align,
(size * size_of::<T>()) % (align/ size_of::<T>()),
0,
"size={} must be a multiple of align={}",
size,
@@ -121,7 +121,7 @@ pub fn alloc_aligned_custom<T>(size: usize, align: usize) -> Vec<T> {
/// of [DEFAULTALIGN]/size_of::<T>() that is equal to or greater than `size`.
pub fn alloc_aligned<T>(size: usize) -> Vec<T> {
alloc_aligned_custom::<T>(
size + (size % (DEFAULTALIGN / size_of::<T>())),
size + (DEFAULTALIGN - (size % (DEFAULTALIGN / size_of::<T>()))),
DEFAULTALIGN,
)
}

View File

@@ -82,6 +82,12 @@ pub trait MatZnxDftOps<BACKEND: Backend> {
where
A: MatZnxToMut<FFT64>;
/// Multiplies A by (X^{k} - 1).
fn mat_znx_dft_mul_x_pow_minus_one_add_inplace<R, A>(&self, k: i64, res: &mut R, a: &A, scratch: &mut Scratch)
where
R: MatZnxToMut<FFT64>,
A: MatZnxToRef<FFT64>;
/// Applies the vector matrix product [VecZnxDft] x [MatZnxDft].
/// The size of `buf` is given by [MatZnxDftOps::vmp_apply_dft_to_dft_tmp_bytes].
///
@@ -212,7 +218,7 @@ impl MatZnxDftOps<FFT64> for Module<FFT64> {
self.mat_znx_dft_set_row(&mut res, row_i, col_j, &tmp_1);
});
})
});
}
fn mat_znx_dft_mul_x_pow_minus_one_inplace<A>(&self, k: i64, a: &mut A, scratch: &mut Scratch)
@@ -249,7 +255,52 @@ impl MatZnxDftOps<FFT64> for Module<FFT64> {
self.mat_znx_dft_set_row(&mut a, row_i, col_j, &tmp_1);
});
})
});
}
/// Accumulates `a * (X^k - 1)` into `res`, row by row, in the DFT domain:
/// `res[row][col] += a[row][col] * (X^k - 1)`.
///
/// NOTE(review): non-accumulating sibling of
/// `mat_znx_dft_mul_x_pow_minus_one_inplace`; `a` is read-only here.
fn mat_znx_dft_mul_x_pow_minus_one_add_inplace<R, A>(&self, k: i64, res: &mut R, a: &A, scratch: &mut Scratch)
where
R: MatZnxToMut<FFT64>,
A: MatZnxToRef<FFT64>,
{
let mut res: MatZnxDft<&mut [u8], FFT64> = res.to_mut();
let a: MatZnxDft<&[u8], FFT64> = a.to_ref();
#[cfg(debug_assertions)]
{
// Ring degree of the operand must match this module's ring degree.
assert_eq!(a.n(), self.n());
}
// Prepare the DFT of the monomial X^k: start from the constant polynomial 1,
// rotate it by k, then convert for scalar-vector products (svp_apply below).
// The inner scope releases the time-domain scratch `xpm1` before the row loop.
let (mut xpm1_dft, scratch1) = scratch.tmp_scalar_znx_dft(self, 1);
{
let (mut xpm1, _) = scratch1.tmp_scalar_znx(self, 1);
xpm1.data[0] = 1;
self.vec_znx_rotate_inplace(k, &mut xpm1, 0);
self.svp_prepare(&mut xpm1_dft, 0, &xpm1, 0);
}
// Two row-sized DFT buffers: tmp_0 holds the row being read (first `a`,
// then `res`), tmp_1 holds the intermediate product a * (X^k - 1).
let (mut tmp_0, scratch2) = scratch1.tmp_vec_znx_dft(self, a.cols_out(), a.size());
let (mut tmp_1, _) = scratch2.tmp_vec_znx_dft(self, a.cols_out(), a.size());
(0..a.rows()).for_each(|row_i| {
(0..a.cols_in()).for_each(|col_j| {
// tmp_1 = a_row * X^k - a_row  ==  a_row * (X^k - 1)
self.mat_znx_dft_get_row(&mut tmp_0, &a, row_i, col_j);
(0..tmp_0.cols()).for_each(|i| {
self.svp_apply(&mut tmp_1, i, &xpm1_dft, 0, &tmp_0, i);
self.vec_znx_dft_sub_ab_inplace(&mut tmp_1, i, &tmp_0, i);
});
// Re-load tmp_0 with the current res row, accumulate tmp_1 into it,
// then write the sum back: res_row += a_row * (X^k - 1).
self.mat_znx_dft_get_row(&mut tmp_0, &res, row_i, col_j);
(0..tmp_0.cols()).for_each(|i| {
self.vec_znx_dft_add_inplace(&mut tmp_0, i, &tmp_1, i);
});
self.mat_znx_dft_set_row(&mut res, row_i, col_j, &tmp_0);
});
});
}
fn mat_znx_dft_set_row<R, A>(&self, res: &mut R, res_row: usize, res_col_in: usize, a: &A)
@@ -845,7 +896,6 @@ mod tests {
(0..cols_out).for_each(|j| {
module.vec_znx_idft(&mut tmp_big, 0, &tmp_dft, j, scratch.borrow());
// module.vec_znx_big_normalize(basek, &mut want, j, &tmp_big, 0, scratch.borrow());
module.vec_znx_big_normalize(basek, &mut tmp, 0, &tmp_big, 0, scratch.borrow());
module.vec_znx_rotate(k, &mut want, j, &tmp, 0);
module.vec_znx_sub_ab_inplace(&mut want, j, &tmp, 0);
@@ -863,4 +913,84 @@ mod tests {
});
});
}
#[test]
fn mat_znx_dft_mul_x_pow_minus_one_add_inplace() {
let log_n: i32 = 5;
let n: usize = 1 << log_n;
let module: Module<FFT64> = Module::<FFT64>::new(n);
let basek: usize = 8;
let rows: usize = 2;
let cols_in: usize = 2;
let cols_out: usize = 2;
let size: usize = 4;
let mut scratch: ScratchOwned = ScratchOwned::new(module.mat_znx_dft_mul_x_pow_minus_one_scratch_space(size, cols_out));
let mut mat_want: MatZnxDft<Vec<u8>, FFT64> = module.new_mat_znx_dft(rows, cols_in, cols_out, size);
let mut mat_have: MatZnxDft<Vec<u8>, FFT64> = module.new_mat_znx_dft(rows, cols_in, cols_out, size);
let mut tmp: VecZnx<Vec<u8>> = module.new_vec_znx(1, size);
let mut tmp_dft: VecZnxDft<Vec<u8>, FFT64> = module.new_vec_znx_dft(cols_out, size);
// Deterministic PRNG; the same seed is replayed below to reconstruct
// the expected values without storing copies of the inputs.
let mut source: Source = Source::new([0u8; 32]);
// Fill mat_have (the accumulator, i.e. `res`) with uniform rows.
// NOTE: mat_have is drawn FIRST — the verification loop below relies on
// this exact draw order when replaying the seed.
(0..mat_have.rows()).for_each(|row_i| {
(0..mat_have.cols_in()).for_each(|col_i| {
(0..cols_out).for_each(|j| {
tmp.fill_uniform(basek, 0, size, &mut source);
module.vec_znx_dft(1, 0, &mut tmp_dft, j, &tmp, 0);
});
module.mat_znx_dft_set_row(&mut mat_have, row_i, col_i, &tmp_dft);
});
});
// Fill mat_want (the read-only operand `a`) with uniform rows.
(0..mat_want.rows()).for_each(|row_i| {
(0..mat_want.cols_in()).for_each(|col_i| {
(0..cols_out).for_each(|j| {
tmp.fill_uniform(basek, 0, size, &mut source);
module.vec_znx_dft(1, 0, &mut tmp_dft, j, &tmp, 0);
});
module.mat_znx_dft_set_row(&mut mat_want, row_i, col_i, &tmp_dft);
});
});
let k: i64 = 1;
// mat_have += mat_want * (X^k - 1)
module.mat_znx_dft_mul_x_pow_minus_one_add_inplace(k, &mut mat_have, &mat_want, scratch.borrow());
let mut have: VecZnx<Vec<u8>> = module.new_vec_znx(cols_out, size);
let mut want: VecZnx<Vec<u8>> = module.new_vec_znx(cols_out, size);
let mut tmp_big: VecZnxBig<Vec<u8>, FFT64> = module.new_vec_znx_big(1, size);
// Fresh source with the SAME seed: replaying it regenerates, in order,
// the coefficients originally drawn into mat_have.
let mut source: Source = Source::new([0u8; 32]);
(0..mat_want.rows()).for_each(|row_i| {
(0..mat_want.cols_in()).for_each(|col_i| {
// Recompute the expected row in the time domain:
// want = a * X^k - a + mat_have_original.
module.mat_znx_dft_get_row(&mut tmp_dft, &mat_want, row_i, col_i);
(0..cols_out).for_each(|j| {
module.vec_znx_idft(&mut tmp_big, 0, &tmp_dft, j, scratch.borrow());
module.vec_znx_big_normalize(basek, &mut tmp, 0, &tmp_big, 0, scratch.borrow());
module.vec_znx_rotate(k, &mut want, j, &tmp, 0);
module.vec_znx_sub_ab_inplace(&mut want, j, &tmp, 0);
// Replays the original draw for mat_have[row_i][col_i][j]
// (same seed, same draw order as the first fill loop above).
tmp.fill_uniform(basek, 0, size, &mut source);
module.vec_znx_add_inplace(&mut want, j, &tmp, 0);
module.vec_znx_normalize_inplace(basek, &mut want, j, scratch.borrow());
});
// Bring the actual accumulated row back to the time domain and compare.
module.mat_znx_dft_get_row(&mut tmp_dft, &mat_have, row_i, col_i);
(0..cols_out).for_each(|j| {
module.vec_znx_idft(&mut tmp_big, 0, &tmp_dft, j, scratch.borrow());
module.vec_znx_big_normalize(basek, &mut have, j, &tmp_big, 0, scratch.borrow());
});
assert_eq!(have, want)
});
});
}
}

View File

@@ -91,9 +91,9 @@ impl<D: AsMut<[u8]> + AsRef<[u8]>> ScalarZnx<D> {
}
pub fn fill_binary_block(&mut self, col: usize, block_size: usize, source: &mut Source) {
assert!(block_size & (block_size - 1) == 0);
assert!(self.n() % block_size == 0);
let max_idx: u64 = (block_size + 1) as u64;
let mask_idx: u64 = (2 * block_size - 1) as u64;
let mask_idx: u64 = (1<<((u64::BITS - max_idx.leading_zeros())as u64)) - 1 ;
for block in self.at_mut(col, 0).chunks_mut(block_size) {
let idx: usize = source.next_u64n(max_idx, mask_idx) as usize;
if idx != block_size {

View File

@@ -177,7 +177,7 @@ impl<D: From<Vec<u8>>> VecZnx<D> {
n * cols * size * size_of::<Scalar>()
}
pub(crate) fn new<Scalar: Sized>(n: usize, cols: usize, size: usize) -> Self {
pub fn new<Scalar: Sized>(n: usize, cols: usize, size: usize) -> Self {
let data = alloc_aligned::<u8>(Self::bytes_of::<Scalar>(n, cols, size));
Self {
data: data.into(),
@@ -243,7 +243,13 @@ fn normalize_tmp_bytes(n: usize) -> usize {
n * std::mem::size_of::<i64>()
}
#[allow(dead_code)]
// Inherent convenience wrapper: exposes the file-local `normalize` free
// function as a method on VecZnx.
impl<D: AsRef<[u8]> + AsMut<[u8]>> VecZnx<D> {
    /// Normalizes the digits of column `a_col` of this vector in base `2^basek`.
    ///
    /// `tmp_bytes` is caller-provided scratch space — presumably sized by
    /// [normalize_tmp_bytes]; confirm against the free `normalize` function.
    pub fn normalize(&mut self, basek: usize, a_col: usize, tmp_bytes: &mut [u8]) {
        normalize(basek, self, a_col, tmp_bytes)
    }
}
fn normalize<D: AsMut<[u8]> + AsRef<[u8]>>(basek: usize, a: &mut VecZnx<D>, a_col: usize, tmp_bytes: &mut [u8]) {
let n: usize = a.n();