Added mat_znx_dft_mul_x_pow_minus_one

Jean-Philippe Bossuat
2025-06-11 18:04:57 +02:00
parent 4455afdabd
commit c77a819653
7 changed files with 297 additions and 35 deletions


@@ -2,8 +2,8 @@ use crate::ffi::vec_znx_dft::vec_znx_dft_t;
use crate::ffi::vmp;
use crate::znx_base::{ZnxInfos, ZnxView, ZnxViewMut};
use crate::{
Backend, FFT64, MatZnxDft, MatZnxDftOwned, MatZnxToMut, MatZnxToRef, Module, Scratch, VecZnxDft, VecZnxDftToMut,
VecZnxDftToRef,
Backend, FFT64, MatZnxDft, MatZnxDftOwned, MatZnxToMut, MatZnxToRef, Module, ScalarZnxAlloc, ScalarZnxDftAlloc,
ScalarZnxDftOps, Scratch, VecZnxDft, VecZnxDftAlloc, VecZnxDftOps, VecZnxDftToMut, VecZnxDftToRef, VecZnxOps, ZnxZero,
};
pub trait MatZnxDftAlloc<B: Backend> {
@@ -38,6 +38,8 @@ pub trait MatZnxDftScratch {
b_cols_out: usize,
b_size: usize,
) -> usize;
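/// Returns the scratch space (in bytes) required by [MatZnxDftOps::mat_znx_dft_mul_x_pow_minus_one].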
fn mat_znx_dft_mul_x_pow_minus_one_scratch_space(&self, size: usize, cols_out: usize) -> usize;
}
/// This trait implements methods for the vector-matrix product,
@@ -52,7 +54,7 @@ pub trait MatZnxDftOps<BACKEND: Backend> {
/// * `row_i`: the index of the row to prepare.
///
/// The size of buf can be obtained with [MatZnxDftOps::vmp_prepare_tmp_bytes].
fn vmp_prepare_row<R, A>(&self, res: &mut R, res_row: usize, res_col_in: usize, a: &A)
fn mat_znx_dft_set_row<R, A>(&self, res: &mut R, res_row: usize, res_col_in: usize, a: &A)
where
R: MatZnxToMut<FFT64>,
A: VecZnxDftToRef<FFT64>;
@@ -64,11 +66,22 @@ pub trait MatZnxDftOps<BACKEND: Backend> {
/// * `res`: the [VecZnxDft] into which the row of the [MatZnxDft] is extracted.
/// * `a`: [MatZnxDft] on which the values are encoded.
/// * `row_i`: the index of the row to extract.
fn vmp_extract_row<R, A>(&self, res: &mut R, a: &A, a_row: usize, a_col_in: usize)
fn mat_znx_dft_get_row<R, A>(&self, res: &mut R, a: &A, a_row: usize, a_col_in: usize)
where
R: VecZnxDftToMut<FFT64>,
A: MatZnxToRef<FFT64>;
/// Multiplies `a` by (X^{k} - 1) and stores the result in `res`.
fn mat_znx_dft_mul_x_pow_minus_one<R, A>(&self, k: i64, res: &mut R, a: &A, scratch: &mut Scratch)
where
R: MatZnxToMut<FFT64>,
A: MatZnxToRef<FFT64>;
/// Multiplies `a` by (X^{k} - 1) in place.
fn mat_znx_dft_mul_x_pow_minus_one_inplace<A>(&self, k: i64, a: &mut A, scratch: &mut Scratch)
where
A: MatZnxToMut<FFT64>;
/// Applies the vector-matrix product [VecZnxDft] x [MatZnxDft].
/// The size of `buf` is given by [MatZnxDftOps::vmp_apply_dft_to_dft_tmp_bytes].
///
@@ -149,10 +162,97 @@ impl<BACKEND: Backend> MatZnxDftScratch for Module<BACKEND> {
) as usize
}
}
fn mat_znx_dft_mul_x_pow_minus_one_scratch_space(&self, size: usize, cols_out: usize) -> usize {
let xpm1_dft: usize = self.bytes_of_scalar_znx_dft(1);
let xpm1: usize = self.bytes_of_scalar_znx(1);
let tmp: usize = self.bytes_of_vec_znx_dft(cols_out, size);
// xpm1 is released before the two row buffers are taken, so only the
// larger of the two regions is live at any one time.
xpm1_dft + xpm1.max(2 * tmp)
}
}
impl MatZnxDftOps<FFT64> for Module<FFT64> {
fn vmp_prepare_row<R, A>(&self, res: &mut R, res_row: usize, res_col_in: usize, a: &A)
fn mat_znx_dft_mul_x_pow_minus_one<R, A>(&self, k: i64, res: &mut R, a: &A, scratch: &mut Scratch)
where
R: MatZnxToMut<FFT64>,
A: MatZnxToRef<FFT64>,
{
let mut res: MatZnxDft<&mut [u8], FFT64> = res.to_mut();
let a: MatZnxDft<&[u8], FFT64> = a.to_ref();
#[cfg(debug_assertions)]
{
assert_eq!(res.n(), self.n());
assert_eq!(a.n(), self.n());
assert_eq!(res.rows(), a.rows());
assert_eq!(res.cols_in(), a.cols_in());
assert_eq!(res.cols_out(), a.cols_out());
}
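// Build X^k in DFT form: start from the constant polynomial 1, rotate it by k,
// then run svp_prepare. The "-1" term of (X^{k} - 1) is applied later via
// vec_znx_dft_sub_ab_inplace.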
let (mut xpm1_dft, scratch1) = scratch.tmp_scalar_znx_dft(self, 1);
{
let (mut xpm1, _) = scratch1.tmp_scalar_znx(self, 1);
xpm1.data[0] = 1;
self.vec_znx_rotate_inplace(k, &mut xpm1, 0);
self.svp_prepare(&mut xpm1_dft, 0, &xpm1, 0);
}
let (mut tmp_0, scratch2) = scratch1.tmp_vec_znx_dft(self, res.cols_out(), res.size());
let (mut tmp_1, _) = scratch2.tmp_vec_znx_dft(self, res.cols_out(), res.size());
(0..res.rows()).for_each(|row_i| {
(0..res.cols_in()).for_each(|col_j| {
self.mat_znx_dft_get_row(&mut tmp_0, &a, row_i, col_j);
(0..tmp_0.cols()).for_each(|i| {
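// tmp_1 <- X^k * tmp_0 (via the prepared SVP), then tmp_1 <- tmp_1 - tmp_0,
// giving tmp_1 = (X^{k} - 1) * tmp_0.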
self.svp_apply(&mut tmp_1, i, &xpm1_dft, 0, &tmp_0, i);
self.vec_znx_dft_sub_ab_inplace(&mut tmp_1, i, &tmp_0, i);
});
self.mat_znx_dft_set_row(&mut res, row_i, col_j, &tmp_1);
});
})
}
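// Usage sketch: assuming the allocation and scratch helpers exercised in the
// tests below, multiplying every row of `mat` by (X^3 - 1) in place looks like:
//
// let module: Module<FFT64> = Module::<FFT64>::new(32);
// let (rows, cols_in, cols_out, size) = (2, 1, 2, 4);
// let mut scratch: ScratchOwned =
//     ScratchOwned::new(module.mat_znx_dft_mul_x_pow_minus_one_scratch_space(size, cols_out));
// let mut mat: MatZnxDft<Vec<u8>, FFT64> = module.new_mat_znx_dft(rows, cols_in, cols_out, size);
// module.mat_znx_dft_mul_x_pow_minus_one_inplace(3, &mut mat, scratch.borrow());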
fn mat_znx_dft_mul_x_pow_minus_one_inplace<A>(&self, k: i64, a: &mut A, scratch: &mut Scratch)
where
A: MatZnxToMut<FFT64>,
{
let mut a: MatZnxDft<&mut [u8], FFT64> = a.to_mut();
#[cfg(debug_assertions)]
{
assert_eq!(a.n(), self.n());
}
let (mut xpm1_dft, scratch1) = scratch.tmp_scalar_znx_dft(self, 1);
{
let (mut xpm1, _) = scratch1.tmp_scalar_znx(self, 1);
xpm1.data[0] = 1;
self.vec_znx_rotate_inplace(k, &mut xpm1, 0);
self.svp_prepare(&mut xpm1_dft, 0, &xpm1, 0);
}
let (mut tmp_0, scratch2) = scratch1.tmp_vec_znx_dft(self, a.cols_out(), a.size());
let (mut tmp_1, _) = scratch2.tmp_vec_znx_dft(self, a.cols_out(), a.size());
(0..a.rows()).for_each(|row_i| {
(0..a.cols_in()).for_each(|col_j| {
self.mat_znx_dft_get_row(&mut tmp_0, &a, row_i, col_j);
(0..tmp_0.cols()).for_each(|i| {
self.svp_apply(&mut tmp_1, i, &xpm1_dft, 0, &tmp_0, i);
self.vec_znx_dft_sub_ab_inplace(&mut tmp_1, i, &tmp_0, i);
});
self.mat_znx_dft_set_row(&mut a, row_i, col_j, &tmp_1);
});
})
}
fn mat_znx_dft_set_row<R, A>(&self, res: &mut R, res_row: usize, res_col_in: usize, a: &A)
where
R: MatZnxToMut<FFT64>,
A: VecZnxDftToRef<FFT64>,
@@ -204,7 +304,7 @@ impl MatZnxDftOps<FFT64> for Module<FFT64> {
}
}
fn vmp_extract_row<R, A>(&self, res: &mut R, a: &A, a_row: usize, a_col_in: usize)
fn mat_znx_dft_get_row<R, A>(&self, res: &mut R, a: &A, a_row: usize, a_col_in: usize)
where
R: VecZnxDftToMut<FFT64>,
A: MatZnxToRef<FFT64>,
@@ -376,7 +476,7 @@ mod tests {
use super::{MatZnxDftAlloc, MatZnxDftScratch};
#[test]
fn vmp_prepare_row() {
fn vmp_set_row() {
let module: Module<FFT64> = Module::<FFT64>::new(16);
let basek: usize = 8;
let mat_rows: usize = 4;
@@ -395,8 +495,8 @@ mod tests {
a.fill_uniform(basek, col_out, mat_size, &mut source);
module.vec_znx_dft(1, 0, &mut a_dft, col_out, &a, col_out);
});
module.vmp_prepare_row(&mut mat, row_i, col_in, &a_dft);
module.vmp_extract_row(&mut b_dft, &mat, row_i, col_in);
module.mat_znx_dft_set_row(&mut mat, row_i, col_in, &a_dft);
module.mat_znx_dft_get_row(&mut b_dft, &mat, row_i, col_in);
assert_eq!(a_dft.raw(), b_dft.raw());
}
}
@@ -413,10 +513,10 @@ mod tests {
let mat_size: usize = 6;
let res_size: usize = a_size;
[1, 2].iter().for_each(|in_cols| {
[1, 2].iter().for_each(|out_cols| {
let a_cols: usize = *in_cols;
let res_cols: usize = *out_cols;
[1, 2].iter().for_each(|cols_in| {
[1, 2].iter().for_each(|cols_out| {
let a_cols: usize = *cols_in;
let res_cols: usize = *cols_out;
let mat_rows: usize = a_size;
let mat_cols_in: usize = a_cols;
@@ -456,7 +556,7 @@ mod tests {
module.vec_znx_dft(1, 0, &mut c_dft, col_out_i, &tmp, col_out_i);
tmp.at_mut(col_out_i, row_i)[idx] = 0 as i64;
});
module.vmp_prepare_row(&mut mat_znx_dft, row_i, col_in_i, &c_dft);
module.mat_znx_dft_set_row(&mut mat_znx_dft, row_i, col_in_i, &c_dft);
});
});
@@ -499,11 +599,11 @@ mod tests {
let res_size: usize = a_size;
let mut source: Source = Source::new([0u8; 32]);
[1, 2].iter().for_each(|in_cols| {
[1, 2].iter().for_each(|out_cols| {
[1, 2].iter().for_each(|cols_in| {
[1, 2].iter().for_each(|cols_out| {
(0..res_size).for_each(|shift| {
let a_cols: usize = *in_cols;
let res_cols: usize = *out_cols;
let a_cols: usize = *cols_in;
let res_cols: usize = *cols_out;
let mat_rows: usize = a_size;
let mat_cols_in: usize = a_cols;
@@ -543,7 +643,7 @@ mod tests {
module.vec_znx_dft(1, 0, &mut c_dft, col_out_i, &tmp, col_out_i);
tmp.at_mut(col_out_i, row_i)[idx] = 0 as i64;
});
module.vmp_prepare_row(&mut mat_znx_dft, row_i, col_in_i, &c_dft);
module.mat_znx_dft_set_row(&mut mat_znx_dft, row_i, col_in_i, &c_dft);
});
});
@@ -601,13 +701,13 @@ mod tests {
let mat_size: usize = 6;
let res_size: usize = a_size;
[1, 2].iter().for_each(|in_cols| {
[1, 2].iter().for_each(|out_cols| {
[1, 2].iter().for_each(|cols_in| {
[1, 2].iter().for_each(|cols_out| {
[1, 3, 6].iter().for_each(|digits| {
let mut source: Source = Source::new([0u8; 32]);
let a_cols: usize = *in_cols;
let res_cols: usize = *out_cols;
let a_cols: usize = *cols_in;
let res_cols: usize = *cols_out;
let mat_rows: usize = a_size;
let mat_cols_in: usize = a_cols;
@@ -652,7 +752,7 @@ mod tests {
module.vec_znx_dft(1, 0, &mut c_dft, col_out_i, &tmp, col_out_i);
tmp.at_mut(col_out_i, limb)[idx] = 0 as i64;
});
module.vmp_prepare_row(&mut mat_znx_dft, row_i, col_in_i, &c_dft);
module.mat_znx_dft_set_row(&mut mat_znx_dft, row_i, col_in_i, &c_dft);
});
});
@@ -697,4 +797,70 @@ mod tests {
});
});
}
#[test]
fn mat_znx_dft_mul_x_pow_minus_one() {
let log_n: i32 = 5;
let n: usize = 1 << log_n;
let module: Module<FFT64> = Module::<FFT64>::new(n);
let basek: usize = 8;
let rows: usize = 2;
let cols_in: usize = 2;
let cols_out: usize = 2;
let size: usize = 4;
let mut scratch: ScratchOwned = ScratchOwned::new(module.mat_znx_dft_mul_x_pow_minus_one_scratch_space(size, cols_out));
let mut mat_want: MatZnxDft<Vec<u8>, FFT64> = module.new_mat_znx_dft(rows, cols_in, cols_out, size);
let mut mat_have: MatZnxDft<Vec<u8>, FFT64> = module.new_mat_znx_dft(rows, cols_in, cols_out, size);
let mut tmp: VecZnx<Vec<u8>> = module.new_vec_znx(1, size);
let mut tmp_dft: VecZnxDft<Vec<u8>, FFT64> = module.new_vec_znx_dft(cols_out, size);
let mut source: Source = Source::new([0u8; 32]);
(0..mat_want.rows()).for_each(|row_i| {
(0..mat_want.cols_in()).for_each(|col_i| {
(0..cols_out).for_each(|j| {
tmp.fill_uniform(basek, 0, size, &mut source);
module.vec_znx_dft(1, 0, &mut tmp_dft, j, &tmp, 0);
});
module.mat_znx_dft_set_row(&mut mat_want, row_i, col_i, &tmp_dft);
});
});
let k: i64 = 1;
module.mat_znx_dft_mul_x_pow_minus_one(k, &mut mat_have, &mat_want, scratch.borrow());
let mut have: VecZnx<Vec<u8>> = module.new_vec_znx(cols_out, size);
let mut want: VecZnx<Vec<u8>> = module.new_vec_znx(cols_out, size);
let mut tmp_big: VecZnxBig<Vec<u8>, FFT64> = module.new_vec_znx_big(1, size);
(0..mat_want.rows()).for_each(|row_i| {
(0..mat_want.cols_in()).for_each(|col_i| {
module.mat_znx_dft_get_row(&mut tmp_dft, &mat_want, row_i, col_i);
(0..cols_out).for_each(|j| {
module.vec_znx_idft(&mut tmp_big, 0, &tmp_dft, j, scratch.borrow());
module.vec_znx_big_normalize(basek, &mut tmp, 0, &tmp_big, 0, scratch.borrow());
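// Reference computation: (X^{k} - 1) * a = rotate(a, k) - a.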
module.vec_znx_rotate(k, &mut want, j, &tmp, 0);
module.vec_znx_sub_ab_inplace(&mut want, j, &tmp, 0);
module.vec_znx_normalize_inplace(basek, &mut want, j, scratch.borrow());
});
module.mat_znx_dft_get_row(&mut tmp_dft, &mat_have, row_i, col_i);
(0..cols_out).for_each(|j| {
module.vec_znx_idft(&mut tmp_big, 0, &tmp_dft, j, scratch.borrow());
module.vec_znx_big_normalize(basek, &mut have, j, &tmp_big, 0, scratch.borrow());
});
assert_eq!(have, want)
});
});
}
}


@@ -53,6 +53,22 @@ pub trait VecZnxDftOps<B: Backend> {
R: VecZnxDftToMut<B>,
A: VecZnxDftToRef<B>;
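/// Subtracts `b` from `a` and stores the result in `res`: `res = a - b`.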
fn vec_znx_dft_sub<R, A, D>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize, b: &D, b_col: usize)
where
R: VecZnxDftToMut<B>,
A: VecZnxDftToRef<B>,
D: VecZnxDftToRef<B>;
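/// Subtracts `a` from `res` in place: `res = res - a`.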
fn vec_znx_dft_sub_ab_inplace<R, A>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize)
where
R: VecZnxDftToMut<B>,
A: VecZnxDftToRef<B>;
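/// Subtracts `res` from `a`, storing the result in `res`: `res = a - res`.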
fn vec_znx_dft_sub_ba_inplace<R, A>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize)
where
R: VecZnxDftToMut<B>,
A: VecZnxDftToRef<B>;
fn vec_znx_dft_copy<R, A>(&self, step: usize, offset: usize, res: &mut R, res_col: usize, a: &A, a_col: usize)
where
R: VecZnxDftToMut<B>,
@@ -150,6 +166,86 @@ impl VecZnxDftOps<FFT64> for Module<FFT64> {
}
}
fn vec_znx_dft_sub<R, A, D>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize, b: &D, b_col: usize)
where
R: VecZnxDftToMut<FFT64>,
A: VecZnxDftToRef<FFT64>,
D: VecZnxDftToRef<FFT64>,
{
let mut res_mut: VecZnxDft<&mut [u8], FFT64> = res.to_mut();
let a_ref: VecZnxDft<&[u8], FFT64> = a.to_ref();
let b_ref: VecZnxDft<&[u8], FFT64> = b.to_ref();
let min_size: usize = res_mut.size().min(a_ref.size()).min(b_ref.size());
unsafe {
(0..min_size).for_each(|j| {
vec_znx_dft::vec_dft_sub(
self.ptr,
res_mut.at_mut_ptr(res_col, j) as *mut vec_znx_dft::vec_znx_dft_t,
1,
a_ref.at_ptr(a_col, j) as *const vec_znx_dft::vec_znx_dft_t,
1,
b_ref.at_ptr(b_col, j) as *const vec_znx_dft::vec_znx_dft_t,
1,
);
});
}
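// Zero the remaining limbs of `res` beyond the smaller operand.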
(min_size..res_mut.size()).for_each(|j| {
res_mut.zero_at(res_col, j);
})
}
fn vec_znx_dft_sub_ab_inplace<R, A>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize)
where
R: VecZnxDftToMut<FFT64>,
A: VecZnxDftToRef<FFT64>,
{
let mut res_mut: VecZnxDft<&mut [u8], FFT64> = res.to_mut();
let a_ref: VecZnxDft<&[u8], FFT64> = a.to_ref();
let min_size: usize = res_mut.size().min(a_ref.size());
unsafe {
(0..min_size).for_each(|j| {
vec_znx_dft::vec_dft_sub(
self.ptr,
res_mut.at_mut_ptr(res_col, j) as *mut vec_znx_dft::vec_znx_dft_t,
1,
res_mut.at_ptr(res_col, j) as *const vec_znx_dft::vec_znx_dft_t,
1,
a_ref.at_ptr(a_col, j) as *const vec_znx_dft::vec_znx_dft_t,
1,
);
});
}
}
fn vec_znx_dft_sub_ba_inplace<R, A>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize)
where
R: VecZnxDftToMut<FFT64>,
A: VecZnxDftToRef<FFT64>,
{
let mut res_mut: VecZnxDft<&mut [u8], FFT64> = res.to_mut();
let a_ref: VecZnxDft<&[u8], FFT64> = a.to_ref();
let min_size: usize = res_mut.size().min(a_ref.size());
unsafe {
(0..min_size).for_each(|j| {
vec_znx_dft::vec_dft_sub(
self.ptr,
res_mut.at_mut_ptr(res_col, j) as *mut vec_znx_dft::vec_znx_dft_t,
1,
a_ref.at_ptr(a_col, j) as *const vec_znx_dft::vec_znx_dft_t,
1,
res_mut.at_ptr(res_col, j) as *const vec_znx_dft::vec_znx_dft_t,
1,
);
});
}
}
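// Usage sketch for the in-place variants, assuming two equally sized columns
// allocated with new_vec_znx_dft as in the tests:
//
// let module: Module<FFT64> = Module::<FFT64>::new(16);
// let mut r: VecZnxDft<Vec<u8>, FFT64> = module.new_vec_znx_dft(1, 4);
// let a: VecZnxDft<Vec<u8>, FFT64> = module.new_vec_znx_dft(1, 4);
// module.vec_znx_dft_sub_ab_inplace(&mut r, 0, &a, 0); // r <- r - a
// module.vec_znx_dft_sub_ba_inplace(&mut r, 0, &a, 0); // r <- a - r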
fn vec_znx_dft_copy<R, A>(&self, step: usize, offset: usize, res: &mut R, res_col: usize, a: &A, a_col: usize)
where
R: VecZnxDftToMut<FFT64>,