From ed9c94bbc81ff0f99c1a88db26e0517edbc5e661 Mon Sep 17 00:00:00 2001
From: Jean-Philippe Bossuat
Date: Fri, 6 Jun 2025 12:07:07 +0200
Subject: [PATCH] Fixed API for scaled vmp mat add

---
 backend/src/mat_znx_dft_ops.rs | 161 ++++++++++++++++++---------------
 backend/src/vec_znx.rs         |  29 ++++++
 backend/src/vec_znx_ops.rs     |  17 ++++
 3 files changed, 135 insertions(+), 72 deletions(-)

diff --git a/backend/src/mat_znx_dft_ops.rs b/backend/src/mat_znx_dft_ops.rs
index d0316a9..6af4455 100644
--- a/backend/src/mat_znx_dft_ops.rs
+++ b/backend/src/mat_znx_dft_ops.rs
@@ -358,7 +358,7 @@ impl MatZnxDftOps<FFT64> for Module<FFT64> {
                 b.as_ptr() as *const vmp::vmp_pmat_t,
                 (b.rows() * b.cols_in()) as u64,
                 (b.size() * b.cols_out()) as u64,
-                scale as u64,
+                (scale * b.cols_out()) as u64,
                 tmp_bytes.as_mut_ptr(),
             )
         }
@@ -368,8 +368,8 @@ mod tests {
     use crate::{
         Decoding, FFT64, FillUniform, MatZnxDft, MatZnxDftOps, Module, ScratchOwned, VecZnx, VecZnxAlloc, VecZnxBig,
-        VecZnxBigAlloc, VecZnxBigOps, VecZnxBigScratch, VecZnxDft, VecZnxDftAlloc, VecZnxDftOps, ZnxInfos, ZnxView, ZnxViewMut,
-        ZnxZero,
+        VecZnxBigAlloc, VecZnxBigOps, VecZnxBigScratch, VecZnxDft, VecZnxDftAlloc, VecZnxDftOps, VecZnxOps, ZnxInfos, ZnxView,
+        ZnxViewMut, ZnxZero,
    };
     use sampling::source::Source;
@@ -493,88 +493,105 @@ mod tests {
         let n: usize = 1 << log_n;
         let module: Module<FFT64> = Module::<FFT64>::new(n);
 
-        let basek: usize = 15;
-        let a_size: usize = 5;
+        let basek: usize = 8;
+        let a_size: usize = 6;
         let mat_size: usize = 6;
         let res_size: usize = a_size;
 
         [1, 2].iter().for_each(|in_cols| {
             [1, 2].iter().for_each(|out_cols| {
-                let a_cols: usize = *in_cols;
-                let res_cols: usize = *out_cols;
+                (0..res_size).for_each(|shift| {
+                    let a_cols: usize = *in_cols;
+                    let res_cols: usize = *out_cols;
 
-                let mat_rows: usize = a_size;
-                let mat_cols_in: usize = a_cols;
-                let mat_cols_out: usize = res_cols;
+                    let mat_rows: usize = a_size;
+                    let mat_cols_in: usize = a_cols;
+                    let mat_cols_out: usize = res_cols;
 
-                let mut scratch: ScratchOwned = ScratchOwned::new(
-                    module.vmp_apply_tmp_bytes(
-                        res_size,
-                        a_size,
-                        mat_rows,
-                        mat_cols_in,
-                        mat_cols_out,
-                        mat_size,
-                    ) | module.vec_znx_big_normalize_tmp_bytes(),
-                );
+                    let mut scratch: ScratchOwned = ScratchOwned::new(
+                        module.vmp_apply_tmp_bytes(
+                            res_size,
+                            a_size,
+                            mat_rows,
+                            mat_cols_in,
+                            mat_cols_out,
+                            mat_size,
+                        ) | module.vec_znx_big_normalize_tmp_bytes(),
+                    );
 
-                let mut a: VecZnx<Vec<u8>> = module.new_vec_znx(a_cols, a_size);
+                    let mut a: VecZnx<Vec<u8>> = module.new_vec_znx(a_cols, a_size);
 
-                (0..a_cols).for_each(|i| {
-                    a.at_mut(i, a_size - 1)[i + 1] = 1;
-                });
-
-                let mut mat_znx_dft: MatZnxDft<Vec<u8>, FFT64> =
-                    module.new_mat_znx_dft(mat_rows, mat_cols_in, mat_cols_out, mat_size);
-
-                let mut c_dft: VecZnxDft<Vec<u8>, FFT64> = module.new_vec_znx_dft(mat_cols_out, mat_size);
-                let mut c_big: VecZnxBig<Vec<u8>, FFT64> = module.new_vec_znx_big(mat_cols_out, mat_size);
-
-                let mut tmp: VecZnx<Vec<u8>> = module.new_vec_znx(mat_cols_out, mat_size);
-
-                // Constructs a [MatZnxDft] that performs cyclic rotations on each submatrix.
-                (0..a.size()).for_each(|row_i| {
-                    (0..mat_cols_in).for_each(|col_in_i| {
-                        (0..mat_cols_out).for_each(|col_out_i| {
-                            let idx = 1 + col_in_i * mat_cols_out + col_out_i;
-                            tmp.at_mut(col_out_i, row_i)[idx] = 1 as i64; // X^{idx}
-                            module.vec_znx_dft(&mut c_dft, col_out_i, &tmp, col_out_i);
-                            tmp.at_mut(col_out_i, row_i)[idx] = 0 as i64;
-                        });
-                        module.vmp_prepare_row(&mut mat_znx_dft, row_i, col_in_i, &c_dft);
-                    });
-                });
-
-                let mut a_dft: VecZnxDft<Vec<u8>, FFT64> = module.new_vec_znx_dft(a_cols, a_size);
-                (0..a_cols).for_each(|i| {
-                    module.vec_znx_dft(&mut a_dft, i, &a, i);
-                });
-
-                c_dft.zero();
-                (0..c_dft.cols()).for_each(|i| {
-                    module.vec_znx_dft(&mut c_dft, i, &a, 0);
-                });
-
-                module.vmp_apply_add(&mut c_dft, &a_dft, &mat_znx_dft, 0, scratch.borrow());
-
-                let mut res_have_vi64: Vec<i64> = vec![i64::default(); n];
-
-                let mut res_have: VecZnx<Vec<u8>> = module.new_vec_znx(res_cols, res_size);
-                (0..mat_cols_out).for_each(|i| {
-                    module.vec_znx_idft_tmp_a(&mut c_big, i, &mut c_dft, i);
-                    module.vec_znx_big_normalize(basek, &mut res_have, i, &c_big, i, scratch.borrow());
-                });
-
-                (0..mat_cols_out).for_each(|col_i| {
-                    let mut res_want_vi64: Vec<i64> = vec![i64::default(); n];
                     (0..a_cols).for_each(|i| {
-                        res_want_vi64[(i + 1) + (1 + i * mat_cols_out + col_i)] = 1;
+                        (0..a_size).for_each(|j| {
+                            a.at_mut(i, j)[i + 1] = 1 + j as i64;
+                        });
                     });
-                    res_want_vi64[1] += 1;
 
+                    let mut mat_znx_dft: MatZnxDft<Vec<u8>, FFT64> =
+                        module.new_mat_znx_dft(mat_rows, mat_cols_in, mat_cols_out, mat_size);
 
-                    res_have.decode_vec_i64(col_i, basek, basek * a_size, &mut res_have_vi64);
-                    assert_eq!(res_have_vi64, res_want_vi64);
+                    let mut c_dft: VecZnxDft<Vec<u8>, FFT64> = module.new_vec_znx_dft(mat_cols_out, mat_size);
+                    let mut c_big: VecZnxBig<Vec<u8>, FFT64> = module.new_vec_znx_big(mat_cols_out, mat_size);
+
+                    let mut tmp: VecZnx<Vec<u8>> = module.new_vec_znx(mat_cols_out, mat_size);
+
+                    // Constructs a [MatZnxDft] that performs cyclic rotations on each submatrix.
+                    (0..a.size()).for_each(|row_i| {
+                        (0..mat_cols_in).for_each(|col_in_i| {
+                            (0..mat_cols_out).for_each(|col_out_i| {
+                                let idx = 1 + col_in_i * mat_cols_out + col_out_i;
+                                tmp.at_mut(col_out_i, row_i)[idx] = 1 as i64; // X^{idx}
+                                module.vec_znx_dft(&mut c_dft, col_out_i, &tmp, col_out_i);
+                                tmp.at_mut(col_out_i, row_i)[idx] = 0 as i64;
+                            });
+                            module.vmp_prepare_row(&mut mat_znx_dft, row_i, col_in_i, &c_dft);
+                        });
+                    });
+
+                    let mut a_dft: VecZnxDft<Vec<u8>, FFT64> = module.new_vec_znx_dft(a_cols, a_size);
+                    (0..a_cols).for_each(|i| {
+                        module.vec_znx_dft(&mut a_dft, i, &a, i);
+                    });
+
+                    c_dft.zero();
+                    (0..c_dft.cols()).for_each(|i| {
+                        module.vec_znx_dft(&mut c_dft, i, &a, 0);
+                    });
+
+                    module.vmp_apply_add(&mut c_dft, &a_dft, &mat_znx_dft, shift, scratch.borrow());
+
+                    let mut res_have: VecZnx<Vec<u8>> = module.new_vec_znx(res_cols, mat_size);
+                    (0..mat_cols_out).for_each(|i| {
+                        module.vec_znx_idft_tmp_a(&mut c_big, i, &mut c_dft, i);
+                        module.vec_znx_big_normalize(basek, &mut res_have, i, &c_big, i, scratch.borrow());
+                    });
+
+                    let mut res_want: VecZnx<Vec<u8>> = module.new_vec_znx(res_cols, mat_size);
+
+                    // Reference path: vmp_apply, then a left shift by shift*basek bits, then the addition of a.
+                    module.vmp_apply(&mut c_dft, &a_dft, &mat_znx_dft, scratch.borrow());
+                    (0..mat_cols_out).for_each(|i| {
+                        module.vec_znx_idft_tmp_a(&mut c_big, i, &mut c_dft, i);
+                        module.vec_znx_big_normalize(basek, &mut res_want, i, &c_big, i, scratch.borrow());
+                    });
+                    module.vec_znx_shift_inplace(
+                        basek,
+                        (shift * basek) as i64,
+                        &mut res_want,
+                        scratch.borrow(),
+                    );
+                    (0..res_cols).for_each(|i| {
+                        module.vec_znx_add_inplace(&mut res_want, i, &a, 0);
+                    });
+
+                    let mut res_have_vi64: Vec<i64> = vec![i64::default(); n];
+                    let mut res_want_vi64: Vec<i64> = vec![i64::default(); n];
+
+                    (0..mat_cols_out).for_each(|col_i| {
+                        res_want.decode_vec_i64(col_i, basek, basek * a_size, &mut res_want_vi64);
+                        res_have.decode_vec_i64(col_i, basek, basek * a_size, &mut res_have_vi64);
+                        assert_eq!(res_have_vi64, res_want_vi64);
+                    });
                 });
             });
         });
diff --git a/backend/src/vec_znx.rs b/backend/src/vec_znx.rs
index 35ee483..cbd1d8c 100644
--- a/backend/src/vec_znx.rs
+++ b/backend/src/vec_znx.rs
@@ -131,6 +131,35 @@ impl<D: AsMut<[u8]> + AsRef<[u8]>> VecZnx<D> {
             })
         }
     }
+
+    pub fn lsh(&mut self, basek: usize, k: usize, scratch: &mut Scratch) {
+        let n: usize = self.n();
+        let cols: usize = self.cols();
+        let size: usize = self.size();
+        let steps: usize = k / basek;
+
+        self.raw_mut().rotate_left(n * steps * cols);
+        (0..cols).for_each(|i| {
+            (size - steps..size).for_each(|j| {
+                self.zero_at(i, j);
+            })
+        });
+
+        let k_rem: usize = k % basek;
+
+        if k_rem != 0 {
+            let shift: usize = i64::BITS as usize - k_rem;
+            let (tmp_bytes, _) = scratch.tmp_slice::<u8>(n * size_of::<i64>());
+            (0..cols).for_each(|i| {
+                (0..steps).for_each(|j| {
+                    self.at_mut(i, j).iter_mut().for_each(|xi| {
+                        *xi <<= shift;
+                    });
+                });
+                normalize(basek, self, i, tmp_bytes);
+            });
+        }
+    }
 }
 
 impl<D: AsRef<Vec<u8>>> VecZnx<D> {
diff --git a/backend/src/vec_znx_ops.rs b/backend/src/vec_znx_ops.rs
index 55b1136..2bde61c 100644
--- a/backend/src/vec_znx_ops.rs
+++ b/backend/src/vec_znx_ops.rs
@@ -104,6 +104,12 @@ pub trait VecZnxOps {
     where
         A: VecZnxToMut;
 
+    /// Shifts all columns of `a` by `k` bits.
+    /// A positive `k` applies a left shift, while a negative `k` applies a right shift.
+    fn vec_znx_shift_inplace<A>(&self, basek: usize, k: i64, a: &mut A, scratch: &mut Scratch)
+    where
+        A: VecZnxToMut;
+
     /// Multiplies the selected column of `a` by X^k and stores the result in `res_col` of `res`.
     fn vec_znx_rotate<R, A>(&self, k: i64, res: &mut R, res_col: usize, a: &A, a_col: usize)
     where
@@ -179,6 +185,17 @@ impl VecZnxAlloc for Module<FFT64> {
 }
 
 impl VecZnxOps for Module<FFT64> {
+    fn vec_znx_shift_inplace<A>(&self, basek: usize, k: i64, a: &mut A, scratch: &mut Scratch)
+    where
+        A: VecZnxToMut,
+    {
+        if k > 0 {
+            a.to_mut().lsh(basek, k as usize, scratch);
+        } else {
+            a.to_mut().rsh(basek, k.abs() as usize, scratch);
+        }
+    }
+
     fn vec_znx_copy<R, A>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize)
     where
         R: VecZnxToMut,
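Reviewer note on the first hunk: judging from the one-line fix, the `scale` argument of the underlying spqlios vmp-apply-add routine appears to count limbs of the concatenated output columns rather than limbs per column, so a per-column limb shift has to be multiplied by `b.cols_out()` before crossing the FFI boundary. The reworked test then pins down the caller-visible contract: with `c_dft` preloaded, `vmp_apply_add(&mut c_dft, &a_dft, &mat_znx_dft, shift, scratch.borrow())` must match `vmp_apply` followed by a left shift of the product by `shift * basek` bits via `vec_znx_shift_inplace`, plus the preloaded summand, which is exactly what the `res_want` branch computes.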
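The whole-limb fast path of the new `lsh` is easy to model in isolation. The sketch below is hypothetical demo code, not part of the patch: limbs are stored most-significant first in base 2^basek, so a left shift by `steps * basek` bits is a `rotate_left` by `steps` limbs followed by zeroing the trailing `steps` limbs. `lsh` applies the same pattern to its flat buffer, where one limb plane spans `n * cols` coefficients, hence `rotate_left(n * steps * cols)`.

    const BASEK: usize = 8;

    // Left-shifts a most-significant-first base-2^BASEK limb vector by `k`
    // bits, whole-limb case only (standalone analogue of `VecZnx::lsh`).
    fn lsh_whole_limbs(limbs: &mut [i64], k: usize) {
        assert_eq!(k % BASEK, 0, "demo covers only the whole-limb case");
        let steps = k / BASEK;
        limbs.rotate_left(steps);
        let size = limbs.len();
        limbs[size - steps..].iter_mut().for_each(|x| *x = 0);
    }

    // Recomposes the represented integer; limb 0 is the most significant.
    fn value(limbs: &[i64]) -> i64 {
        limbs.iter().fold(0, |acc, &x| (acc << BASEK) + x)
    }

    fn main() {
        let mut limbs: Vec<i64> = vec![0, 0, 0, 0, 0x12, 0x34]; // holds 0x1234
        let before = value(&limbs);
        lsh_whole_limbs(&mut limbs, 2 * BASEK); // multiply by 2^16
        assert_eq!(value(&limbs), before << (2 * BASEK));
        println!("ok: {:#x} -> {:#x}", before, value(&limbs));
    }

Note that the new test only exercises whole-limb shifts (`k = shift * basek`), so the `k % basek != 0` remainder branch of `lsh` remains unexercised by this patch.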