mirror of
https://github.com/arnaucube/poulpy.git
synced 2026-02-10 05:06:44 +01:00
Fixed API for scaled vmp mat add
This commit is contained in:
@@ -358,7 +358,7 @@ impl MatZnxDftOps<FFT64> for Module<FFT64> {
|
||||
b.as_ptr() as *const vmp::vmp_pmat_t,
|
||||
(b.rows() * b.cols_in()) as u64,
|
||||
(b.size() * b.cols_out()) as u64,
|
||||
scale as u64,
|
||||
(scale * b.cols_out()) as u64,
|
||||
tmp_bytes.as_mut_ptr(),
|
||||
)
|
||||
}
|
||||
@@ -368,8 +368,8 @@ impl MatZnxDftOps<FFT64> for Module<FFT64> {
|
||||
mod tests {
|
||||
use crate::{
|
||||
Decoding, FFT64, FillUniform, MatZnxDft, MatZnxDftOps, Module, ScratchOwned, VecZnx, VecZnxAlloc, VecZnxBig,
|
||||
VecZnxBigAlloc, VecZnxBigOps, VecZnxBigScratch, VecZnxDft, VecZnxDftAlloc, VecZnxDftOps, ZnxInfos, ZnxView, ZnxViewMut,
|
||||
ZnxZero,
|
||||
VecZnxBigAlloc, VecZnxBigOps, VecZnxBigScratch, VecZnxDft, VecZnxDftAlloc, VecZnxDftOps, VecZnxOps, ZnxInfos, ZnxView,
|
||||
ZnxViewMut, ZnxZero,
|
||||
};
|
||||
use sampling::source::Source;
|
||||
|
||||
@@ -493,13 +493,14 @@ mod tests {
|
||||
let n: usize = 1 << log_n;
|
||||
|
||||
let module: Module<FFT64> = Module::<FFT64>::new(n);
|
||||
let basek: usize = 15;
|
||||
let a_size: usize = 5;
|
||||
let basek: usize = 8;
|
||||
let a_size: usize = 6;
|
||||
let mat_size: usize = 6;
|
||||
let res_size: usize = a_size;
|
||||
|
||||
[1, 2].iter().for_each(|in_cols| {
|
||||
[1, 2].iter().for_each(|out_cols| {
|
||||
(0..res_size).for_each(|shift| {
|
||||
let a_cols: usize = *in_cols;
|
||||
let res_cols: usize = *out_cols;
|
||||
|
||||
@@ -521,7 +522,9 @@ mod tests {
|
||||
let mut a: VecZnx<Vec<u8>> = module.new_vec_znx(a_cols, a_size);
|
||||
|
||||
(0..a_cols).for_each(|i| {
|
||||
a.at_mut(i, a_size - 1)[i + 1] = 1;
|
||||
(0..a_size).for_each(|j| {
|
||||
a.at_mut(i, j)[i + 1] = 1 + j as i64;
|
||||
});
|
||||
});
|
||||
|
||||
let mut mat_znx_dft: MatZnxDft<Vec<u8>, FFT64> =
|
||||
@@ -555,28 +558,42 @@ mod tests {
|
||||
module.vec_znx_dft(&mut c_dft, i, &a, 0);
|
||||
});
|
||||
|
||||
module.vmp_apply_add(&mut c_dft, &a_dft, &mat_znx_dft, 0, scratch.borrow());
|
||||
module.vmp_apply_add(&mut c_dft, &a_dft, &mat_znx_dft, shift, scratch.borrow());
|
||||
|
||||
let mut res_have_vi64: Vec<i64> = vec![i64::default(); n];
|
||||
|
||||
let mut res_have: VecZnx<Vec<u8>> = module.new_vec_znx(res_cols, res_size);
|
||||
let mut res_have: VecZnx<Vec<u8>> = module.new_vec_znx(res_cols, mat_size);
|
||||
(0..mat_cols_out).for_each(|i| {
|
||||
module.vec_znx_idft_tmp_a(&mut c_big, i, &mut c_dft, i);
|
||||
module.vec_znx_big_normalize(basek, &mut res_have, i, &c_big, i, scratch.borrow());
|
||||
});
|
||||
|
||||
(0..mat_cols_out).for_each(|col_i| {
|
||||
let mut res_want_vi64: Vec<i64> = vec![i64::default(); n];
|
||||
(0..a_cols).for_each(|i| {
|
||||
res_want_vi64[(i + 1) + (1 + i * mat_cols_out + col_i)] = 1;
|
||||
let mut res_want: VecZnx<Vec<u8>> = module.new_vec_znx(res_cols, mat_size);
|
||||
|
||||
// Equivalent to vmp_add & scale
|
||||
module.vmp_apply(&mut c_dft, &a_dft, &mat_znx_dft, scratch.borrow());
|
||||
(0..mat_cols_out).for_each(|i| {
|
||||
module.vec_znx_idft_tmp_a(&mut c_big, i, &mut c_dft, i);
|
||||
module.vec_znx_big_normalize(basek, &mut res_want, i, &c_big, i, scratch.borrow());
|
||||
});
|
||||
module.vec_znx_shift_inplace(
|
||||
basek,
|
||||
(shift * basek) as i64,
|
||||
&mut res_want,
|
||||
scratch.borrow(),
|
||||
);
|
||||
(0..res_cols).for_each(|i| {
|
||||
module.vec_znx_add_inplace(&mut res_want, i, &a, 0);
|
||||
});
|
||||
|
||||
res_want_vi64[1] += 1;
|
||||
let mut res_have_vi64: Vec<i64> = vec![i64::default(); n];
|
||||
let mut res_want_vi64: Vec<i64> = vec![i64::default(); n];
|
||||
|
||||
(0..mat_cols_out).for_each(|col_i| {
|
||||
res_want.decode_vec_i64(col_i, basek, basek * a_size, &mut res_want_vi64);
|
||||
res_have.decode_vec_i64(col_i, basek, basek * a_size, &mut res_have_vi64);
|
||||
assert_eq!(res_have_vi64, res_want_vi64);
|
||||
});
|
||||
});
|
||||
});
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
@@ -131,6 +131,35 @@ impl<D: AsMut<[u8]> + AsRef<[u8]>> VecZnx<D> {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
pub fn lsh(&mut self, basek: usize, k: usize, scratch: &mut Scratch) {
|
||||
let n: usize = self.n();
|
||||
let cols: usize = self.cols();
|
||||
let size: usize = self.size();
|
||||
let steps: usize = k / basek;
|
||||
|
||||
self.raw_mut().rotate_left(n * steps * cols);
|
||||
(0..cols).for_each(|i| {
|
||||
(size - steps..size).for_each(|j| {
|
||||
self.zero_at(i, j);
|
||||
})
|
||||
});
|
||||
|
||||
let k_rem: usize = k % basek;
|
||||
|
||||
if k_rem != 0 {
|
||||
let shift: usize = i64::BITS as usize - k_rem;
|
||||
let (tmp_bytes, _) = scratch.tmp_slice::<u8>(n * size_of::<i64>());
|
||||
(0..cols).for_each(|i| {
|
||||
(0..steps).for_each(|j| {
|
||||
self.at_mut(i, j).iter_mut().for_each(|xi| {
|
||||
*xi <<= shift;
|
||||
});
|
||||
});
|
||||
normalize(basek, self, i, tmp_bytes);
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<D: From<Vec<u8>>> VecZnx<D> {
|
||||
|
||||
@@ -104,6 +104,12 @@ pub trait VecZnxOps {
|
||||
where
|
||||
A: VecZnxToMut;
|
||||
|
||||
/// Shifts by k bits all columns of `a`.
|
||||
/// A positive k applies a left shift, while a negative k applies a right shift.
|
||||
fn vec_znx_shift_inplace<A>(&self, basek: usize, k: i64, a: &mut A, scratch: &mut Scratch)
|
||||
where
|
||||
A: VecZnxToMut;
|
||||
|
||||
/// Multiplies the selected column of `a` by X^k and stores the result in `res_col` of `res`.
|
||||
fn vec_znx_rotate<R, A>(&self, k: i64, res: &mut R, res_col: usize, a: &A, a_col: usize)
|
||||
where
|
||||
@@ -179,6 +185,17 @@ impl<B: Backend> VecZnxAlloc for Module<B> {
|
||||
}
|
||||
|
||||
impl<BACKEND: Backend> VecZnxOps for Module<BACKEND> {
|
||||
fn vec_znx_shift_inplace<A>(&self, basek: usize, k: i64, a: &mut A, scratch: &mut Scratch)
|
||||
where
|
||||
A: VecZnxToMut,
|
||||
{
|
||||
if k > 0 {
|
||||
a.to_mut().lsh(basek, k as usize, scratch);
|
||||
} else {
|
||||
a.to_mut().rsh(basek, k.abs() as usize, scratch);
|
||||
}
|
||||
}
|
||||
|
||||
fn vec_znx_copy<R, A>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize)
|
||||
where
|
||||
R: VecZnxToMut,
|
||||
|
||||
Reference in New Issue
Block a user