Fixed API for scaled vmp mat add

Jean-Philippe Bossuat
2025-06-06 12:07:07 +02:00
parent 159cd8025f
commit ed9c94bbc8
3 changed files with 135 additions and 72 deletions

File 1 of 3

@@ -358,7 +358,7 @@ impl MatZnxDftOps<FFT64> for Module<FFT64> {
                 b.as_ptr() as *const vmp::vmp_pmat_t,
                 (b.rows() * b.cols_in()) as u64,
                 (b.size() * b.cols_out()) as u64,
-                scale as u64,
+                (scale * b.cols_out()) as u64,
                 tmp_bytes.as_mut_ptr(),
             )
         }
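
Context for the one-line change: the row count handed to the C backend on the previous line already flattens the limb and output-column dimensions into b.size() * b.cols_out(), so an offset of scale limbs must be rescaled to scale * b.cols_out() rows to stay consistent with that layout. A minimal sketch of the index arithmetic, assuming the row-major interleaving that those arguments suggest (flattened_row is an illustrative helper, not a crate or backend function):

    // Maps (limb, out_col) to a row index in a buffer that interleaves
    // cols_out output columns per limb, as the row count size * cols_out
    // passed to the C call suggests. Illustrative only.
    fn flattened_row(limb: usize, out_col: usize, cols_out: usize) -> usize {
        limb * cols_out + out_col
    }

    fn main() {
        let (cols_out, scale) = (2usize, 3usize);
        // An offset of `scale` limbs is `scale * cols_out` rows in the
        // flattened layout: exactly the corrected argument.
        assert_eq!(flattened_row(scale, 0, cols_out), scale * cols_out);
    }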
@@ -368,8 +368,8 @@ mod tests {
 mod tests {
     use crate::{
         Decoding, FFT64, FillUniform, MatZnxDft, MatZnxDftOps, Module, ScratchOwned, VecZnx, VecZnxAlloc, VecZnxBig,
-        VecZnxBigAlloc, VecZnxBigOps, VecZnxBigScratch, VecZnxDft, VecZnxDftAlloc, VecZnxDftOps, ZnxInfos, ZnxView, ZnxViewMut,
-        ZnxZero,
+        VecZnxBigAlloc, VecZnxBigOps, VecZnxBigScratch, VecZnxDft, VecZnxDftAlloc, VecZnxDftOps, VecZnxOps, ZnxInfos, ZnxView,
+        ZnxViewMut, ZnxZero,
     };
     use sampling::source::Source;
@@ -493,13 +493,14 @@ mod tests {
         let n: usize = 1 << log_n;
         let module: Module<FFT64> = Module::<FFT64>::new(n);
-        let basek: usize = 15;
-        let a_size: usize = 5;
+        let basek: usize = 8;
+        let a_size: usize = 6;
         let mat_size: usize = 6;
         let res_size: usize = a_size;
         [1, 2].iter().for_each(|in_cols| {
             [1, 2].iter().for_each(|out_cols| {
+                (0..res_size).for_each(|shift| {
                 let a_cols: usize = *in_cols;
                 let res_cols: usize = *out_cols;
@@ -521,7 +522,9 @@ mod tests {
                 let mut a: VecZnx<Vec<u8>> = module.new_vec_znx(a_cols, a_size);
                 (0..a_cols).for_each(|i| {
-                    a.at_mut(i, a_size - 1)[i + 1] = 1;
+                    (0..a_size).for_each(|j| {
+                        a.at_mut(i, j)[i + 1] = 1 + j as i64;
+                    });
                 });
                 let mut mat_znx_dft: MatZnxDft<Vec<u8>, FFT64> =
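
The richer input matters for the new shift parameter: the old test set a single limb, while the new one writes 1 + j at coefficient i + 1 of every limb j, so each limb of the product is distinct and a misplaced limb shift changes the decoded result instead of just sliding zeros around.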
@@ -555,28 +558,42 @@ mod tests {
                     module.vec_znx_dft(&mut c_dft, i, &a, 0);
                 });
-                module.vmp_apply_add(&mut c_dft, &a_dft, &mat_znx_dft, 0, scratch.borrow());
-                let mut res_have_vi64: Vec<i64> = vec![i64::default(); n];
-                let mut res_have: VecZnx<Vec<u8>> = module.new_vec_znx(res_cols, res_size);
+                module.vmp_apply_add(&mut c_dft, &a_dft, &mat_znx_dft, shift, scratch.borrow());
+                let mut res_have: VecZnx<Vec<u8>> = module.new_vec_znx(res_cols, mat_size);
                 (0..mat_cols_out).for_each(|i| {
                     module.vec_znx_idft_tmp_a(&mut c_big, i, &mut c_dft, i);
                     module.vec_znx_big_normalize(basek, &mut res_have, i, &c_big, i, scratch.borrow());
                 });
-                (0..mat_cols_out).for_each(|col_i| {
-                    let mut res_want_vi64: Vec<i64> = vec![i64::default(); n];
-                    (0..a_cols).for_each(|i| {
-                        res_want_vi64[(i + 1) + (1 + i * mat_cols_out + col_i)] = 1;
-                    });
-                    res_want_vi64[1] += 1;
+                let mut res_want: VecZnx<Vec<u8>> = module.new_vec_znx(res_cols, mat_size);
+
+                // Equivalent to vmp_add & scale
+                module.vmp_apply(&mut c_dft, &a_dft, &mat_znx_dft, scratch.borrow());
+                (0..mat_cols_out).for_each(|i| {
+                    module.vec_znx_idft_tmp_a(&mut c_big, i, &mut c_dft, i);
+                    module.vec_znx_big_normalize(basek, &mut res_want, i, &c_big, i, scratch.borrow());
+                });
+                module.vec_znx_shift_inplace(
+                    basek,
+                    (shift * basek) as i64,
+                    &mut res_want,
+                    scratch.borrow(),
+                );
+                (0..res_cols).for_each(|i| {
+                    module.vec_znx_add_inplace(&mut res_want, i, &a, 0);
+                });
+                let mut res_have_vi64: Vec<i64> = vec![i64::default(); n];
+                let mut res_want_vi64: Vec<i64> = vec![i64::default(); n];
+                (0..mat_cols_out).for_each(|col_i| {
+                    res_want.decode_vec_i64(col_i, basek, basek * a_size, &mut res_want_vi64);
                     res_have.decode_vec_i64(col_i, basek, basek * a_size, &mut res_have_vi64);
                     assert_eq!(res_have_vi64, res_want_vi64);
                 });
             });
         });
+        });
     }
 }
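
Read as a specification, the reworked test asserts for every shift in 0..res_size that the scaled accumulating product matches the unscaled product, explicitly shifted and then accumulated:

    vmp_apply_add(c, a, mat, shift)  ==  lsh(vmp_apply(a, mat), shift * basek) + c

where c initially holds the DFT of a, and lsh by shift * basek bits is a multiplication by 2^(shift * basek) at the vector's fixed precision, i.e. the vec_znx_shift_inplace call in the res_want branch.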

File 2 of 3

@@ -131,6 +131,35 @@ impl<D: AsMut<[u8]> + AsRef<[u8]>> VecZnx<D> {
             })
         }
     }
+
+    pub fn lsh(&mut self, basek: usize, k: usize, scratch: &mut Scratch) {
+        let n: usize = self.n();
+        let cols: usize = self.cols();
+        let size: usize = self.size();
+
+        let steps: usize = k / basek;
+        self.raw_mut().rotate_left(n * steps * cols);
+        (0..cols).for_each(|i| {
+            (size - steps..size).for_each(|j| {
+                self.zero_at(i, j);
+            })
+        });
+
+        let k_rem: usize = k % basek;
+        if k_rem != 0 {
+            let shift: usize = i64::BITS as usize - k_rem;
+            let (tmp_bytes, _) = scratch.tmp_slice::<u8>(n * size_of::<i64>());
+            (0..cols).for_each(|i| {
+                (0..steps).for_each(|j| {
+                    self.at_mut(i, j).iter_mut().for_each(|xi| {
+                        *xi <<= shift;
+                    });
+                });
+                normalize(basek, self, i, tmp_bytes);
+            });
+        }
+    }
 }

 impl<D: From<Vec<u8>>> VecZnx<D> {
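
The new lsh splits a k-bit left shift into a whole-limb part and a sub-limb remainder: with basek bits per limb, it first rotates whole limbs by k / basek positions and zeroes the vacated trailing limbs, then folds the k % basek remainder into a per-coefficient shift followed by carry normalization. A toy sketch of the whole-limb step on a single base-2^basek digit vector, under my reading that limb 0 is the most significant (illustrative code, not the crate's):

    // One integer stored as fixed-size base-2^basek limbs, limb 0 most
    // significant. A left shift by a whole number of limbs rotates the
    // digits up and zeroes the vacated low positions.
    fn lsh_whole_limbs(limbs: &mut [i64], basek: usize, k: usize) {
        let steps = k / basek;
        limbs.rotate_left(steps);
        let size = limbs.len();
        limbs[size - steps..].fill(0);
    }

    fn main() {
        // basek = 8: the limbs below encode 0x01020304.
        let mut limbs = [0x01, 0x02, 0x03, 0x04];
        lsh_whole_limbs(&mut limbs, 8, 16); // multiply by 2^16
        // 0x01020304 << 16, truncated to 32 bits, is 0x03040000.
        assert_eq!(limbs, [0x03, 0x04, 0x00, 0x00]);
    }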

File 3 of 3

@@ -104,6 +104,12 @@ pub trait VecZnxOps {
     where
         A: VecZnxToMut;

+    /// Shifts all columns of `a` by k bits.
+    /// A positive k applies a left shift, while a negative k applies a right shift.
+    fn vec_znx_shift_inplace<A>(&self, basek: usize, k: i64, a: &mut A, scratch: &mut Scratch)
+    where
+        A: VecZnxToMut;
+
     /// Multiplies the selected column of `a` by X^k and stores the result in `res_col` of `res`.
     fn vec_znx_rotate<R, A>(&self, k: i64, res: &mut R, res_col: usize, a: &A, a_col: usize)
     where
@@ -179,6 +185,17 @@ impl<B: Backend> VecZnxAlloc for Module<B> {
 }

 impl<BACKEND: Backend> VecZnxOps for Module<BACKEND> {
+    fn vec_znx_shift_inplace<A>(&self, basek: usize, k: i64, a: &mut A, scratch: &mut Scratch)
+    where
+        A: VecZnxToMut,
+    {
+        if k > 0 {
+            a.to_mut().lsh(basek, k as usize, scratch);
+        } else {
+            a.to_mut().rsh(basek, k.abs() as usize, scratch);
+        }
+    }
+
     fn vec_znx_copy<R, A>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize)
     where
         R: VecZnxToMut,
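
Together with lsh (and the pre-existing rsh it mirrors), the trait method gives a single signed entry point for power-of-two rescaling. A hypothetical usage sketch; the setup types are the ones this commit's test code uses, and the helper itself is illustrative, not part of the crate:

    // Assumes the crate's public items (Module, FFT64, VecZnx, Scratch,
    // VecZnxOps) are in scope. Sign convention of the new method:
    // k > 0 multiplies by 2^k (lsh), k < 0 divides by 2^|k| (rsh).
    fn rescale_by_pow2(
        module: &Module<FFT64>,
        basek: usize,
        a: &mut VecZnx<Vec<u8>>,
        scratch: &mut Scratch,
    ) {
        module.vec_znx_shift_inplace(basek, 16, a, scratch); // * 2^16
        module.vec_znx_shift_inplace(basek, -8, a, scratch); // / 2^8
    }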