Mirror of https://github.com/arnaucube/poulpy.git
Commit: Fixed API for scaled vmp mat add
@@ -358,7 +358,7 @@ impl MatZnxDftOps<FFT64> for Module<FFT64> {
                 b.as_ptr() as *const vmp::vmp_pmat_t,
                 (b.rows() * b.cols_in()) as u64,
                 (b.size() * b.cols_out()) as u64,
-                scale as u64,
+                (scale * b.cols_out()) as u64,
                 tmp_bytes.as_mut_ptr(),
             )
         }
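The only functional change in this hunk is the scale argument passed to the underlying C call (a spqlios-style `vmp_pmat_t` API). As the neighboring arguments suggest, the prepared matrix is addressed in backend rows that interleave its `cols_out` output columns (rows are passed as `rows * cols_in`, sizes as `size * cols_out`), so a logical row offset of `scale` must skip `scale * b.cols_out()` backend rows. A minimal sketch of that index arithmetic; the helper name is hypothetical and not part of poulpy's API:

// Toy model of the flattened layout implied by the arguments above:
// each logical row of the prepared matrix occupies `cols_out`
// consecutive backend rows.
fn backend_row_offset(scale: usize, cols_out: usize) -> usize {
    // Skipping `scale` logical rows means skipping `scale * cols_out`
    // interleaved backend rows, mirroring `(scale * b.cols_out()) as u64`.
    scale * cols_out
}

fn main() {
    // With 2 output columns, a shift of 3 logical rows skips 6 backend rows.
    assert_eq!(backend_row_offset(3, 2), 6);
    println!("ok");
}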
@@ -368,8 +368,8 @@ impl MatZnxDftOps<FFT64> for Module<FFT64> {
 mod tests {
     use crate::{
         Decoding, FFT64, FillUniform, MatZnxDft, MatZnxDftOps, Module, ScratchOwned, VecZnx, VecZnxAlloc, VecZnxBig,
-        VecZnxBigAlloc, VecZnxBigOps, VecZnxBigScratch, VecZnxDft, VecZnxDftAlloc, VecZnxDftOps, ZnxInfos, ZnxView, ZnxViewMut,
-        ZnxZero,
+        VecZnxBigAlloc, VecZnxBigOps, VecZnxBigScratch, VecZnxDft, VecZnxDftAlloc, VecZnxDftOps, VecZnxOps, ZnxInfos, ZnxView,
+        ZnxViewMut, ZnxZero,
     };
     use sampling::source::Source;

@@ -493,88 +493,105 @@ mod tests {
         let n: usize = 1 << log_n;

         let module: Module<FFT64> = Module::<FFT64>::new(n);
-        let basek: usize = 15;
-        let a_size: usize = 5;
+        let basek: usize = 8;
+        let a_size: usize = 6;
         let mat_size: usize = 6;
         let res_size: usize = a_size;

         [1, 2].iter().for_each(|in_cols| {
             [1, 2].iter().for_each(|out_cols| {
-                let a_cols: usize = *in_cols;
-                let res_cols: usize = *out_cols;
+                (0..res_size).for_each(|shift| {
+                    let a_cols: usize = *in_cols;
+                    let res_cols: usize = *out_cols;

                 let mat_rows: usize = a_size;
                 let mat_cols_in: usize = a_cols;
                 let mat_cols_out: usize = res_cols;

                 let mut scratch: ScratchOwned = ScratchOwned::new(
                     module.vmp_apply_tmp_bytes(
                         res_size,
                         a_size,
                         mat_rows,
                         mat_cols_in,
                         mat_cols_out,
                         mat_size,
                     ) | module.vec_znx_big_normalize_tmp_bytes(),
                 );

                 let mut a: VecZnx<Vec<u8>> = module.new_vec_znx(a_cols, a_size);

-                (0..a_cols).for_each(|i| {
-                    a.at_mut(i, a_size - 1)[i + 1] = 1;
-                });
-
-                let mut mat_znx_dft: MatZnxDft<Vec<u8>, FFT64> =
-                    module.new_mat_znx_dft(mat_rows, mat_cols_in, mat_cols_out, mat_size);
-
-                let mut c_dft: VecZnxDft<Vec<u8>, FFT64> = module.new_vec_znx_dft(mat_cols_out, mat_size);
-                let mut c_big: VecZnxBig<Vec<u8>, FFT64> = module.new_vec_znx_big(mat_cols_out, mat_size);
-
-                let mut tmp: VecZnx<Vec<u8>> = module.new_vec_znx(mat_cols_out, mat_size);
-
-                // Constructs a [VecZnxMatDft] that performs cyclic rotations on each submatrix.
-                (0..a.size()).for_each(|row_i| {
-                    (0..mat_cols_in).for_each(|col_in_i| {
-                        (0..mat_cols_out).for_each(|col_out_i| {
-                            let idx = 1 + col_in_i * mat_cols_out + col_out_i;
-                            tmp.at_mut(col_out_i, row_i)[idx] = 1 as i64; // X^{idx}
-                            module.vec_znx_dft(&mut c_dft, col_out_i, &tmp, col_out_i);
-                            tmp.at_mut(col_out_i, row_i)[idx] = 0 as i64;
-                        });
-                        module.vmp_prepare_row(&mut mat_znx_dft, row_i, col_in_i, &c_dft);
-                    });
-                });
-
-                let mut a_dft: VecZnxDft<Vec<u8>, FFT64> = module.new_vec_znx_dft(a_cols, a_size);
-                (0..a_cols).for_each(|i| {
-                    module.vec_znx_dft(&mut a_dft, i, &a, i);
-                });
-
-                c_dft.zero();
-                (0..c_dft.cols()).for_each(|i| {
-                    module.vec_znx_dft(&mut c_dft, i, &a, 0);
-                });
-
-                module.vmp_apply_add(&mut c_dft, &a_dft, &mat_znx_dft, 0, scratch.borrow());
-
-                let mut res_have_vi64: Vec<i64> = vec![i64::default(); n];
-
-                let mut res_have: VecZnx<Vec<u8>> = module.new_vec_znx(res_cols, res_size);
-                (0..mat_cols_out).for_each(|i| {
-                    module.vec_znx_idft_tmp_a(&mut c_big, i, &mut c_dft, i);
-                    module.vec_znx_big_normalize(basek, &mut res_have, i, &c_big, i, scratch.borrow());
-                });
-
-                (0..mat_cols_out).for_each(|col_i| {
-                    let mut res_want_vi64: Vec<i64> = vec![i64::default(); n];
                     (0..a_cols).for_each(|i| {
-                        res_want_vi64[(i + 1) + (1 + i * mat_cols_out + col_i)] = 1;
+                        (0..a_size).for_each(|j| {
+                            a.at_mut(i, j)[i + 1] = 1 + j as i64;
+                        });
                     });

-                    res_want_vi64[1] += 1;
+                    let mut mat_znx_dft: MatZnxDft<Vec<u8>, FFT64> =
+                        module.new_mat_znx_dft(mat_rows, mat_cols_in, mat_cols_out, mat_size);

-                    res_have.decode_vec_i64(col_i, basek, basek * a_size, &mut res_have_vi64);
-                    assert_eq!(res_have_vi64, res_want_vi64);
+                    let mut c_dft: VecZnxDft<Vec<u8>, FFT64> = module.new_vec_znx_dft(mat_cols_out, mat_size);
+                    let mut c_big: VecZnxBig<Vec<u8>, FFT64> = module.new_vec_znx_big(mat_cols_out, mat_size);

+                    let mut tmp: VecZnx<Vec<u8>> = module.new_vec_znx(mat_cols_out, mat_size);
+
+                    // Constructs a [VecZnxMatDft] that performs cyclic rotations on each submatrix.
+                    (0..a.size()).for_each(|row_i| {
+                        (0..mat_cols_in).for_each(|col_in_i| {
+                            (0..mat_cols_out).for_each(|col_out_i| {
+                                let idx = 1 + col_in_i * mat_cols_out + col_out_i;
+                                tmp.at_mut(col_out_i, row_i)[idx] = 1 as i64; // X^{idx}
+                                module.vec_znx_dft(&mut c_dft, col_out_i, &tmp, col_out_i);
+                                tmp.at_mut(col_out_i, row_i)[idx] = 0 as i64;
+                            });
+                            module.vmp_prepare_row(&mut mat_znx_dft, row_i, col_in_i, &c_dft);
+                        });
+                    });
+
+                    let mut a_dft: VecZnxDft<Vec<u8>, FFT64> = module.new_vec_znx_dft(a_cols, a_size);
+                    (0..a_cols).for_each(|i| {
+                        module.vec_znx_dft(&mut a_dft, i, &a, i);
+                    });
+
+                    c_dft.zero();
+                    (0..c_dft.cols()).for_each(|i| {
+                        module.vec_znx_dft(&mut c_dft, i, &a, 0);
+                    });
+
+                    module.vmp_apply_add(&mut c_dft, &a_dft, &mat_znx_dft, shift, scratch.borrow());
+
+                    let mut res_have: VecZnx<Vec<u8>> = module.new_vec_znx(res_cols, mat_size);
+                    (0..mat_cols_out).for_each(|i| {
+                        module.vec_znx_idft_tmp_a(&mut c_big, i, &mut c_dft, i);
+                        module.vec_znx_big_normalize(basek, &mut res_have, i, &c_big, i, scratch.borrow());
+                    });
+
+                    let mut res_want: VecZnx<Vec<u8>> = module.new_vec_znx(res_cols, mat_size);
+
+                    // Equivalent to vmp_add & scale
+                    module.vmp_apply(&mut c_dft, &a_dft, &mat_znx_dft, scratch.borrow());
+                    (0..mat_cols_out).for_each(|i| {
+                        module.vec_znx_idft_tmp_a(&mut c_big, i, &mut c_dft, i);
+                        module.vec_znx_big_normalize(basek, &mut res_want, i, &c_big, i, scratch.borrow());
+                    });
+                    module.vec_znx_shift_inplace(
+                        basek,
+                        (shift * basek) as i64,
+                        &mut res_want,
+                        scratch.borrow(),
+                    );
+                    (0..res_cols).for_each(|i| {
+                        module.vec_znx_add_inplace(&mut res_want, i, &a, 0);
+                    });
+
+                    let mut res_have_vi64: Vec<i64> = vec![i64::default(); n];
+                    let mut res_want_vi64: Vec<i64> = vec![i64::default(); n];
+
+                    (0..mat_cols_out).for_each(|col_i| {
+                        res_want.decode_vec_i64(col_i, basek, basek * a_size, &mut res_want_vi64);
+                        res_have.decode_vec_i64(col_i, basek, basek * a_size, &mut res_have_vi64);
+                        assert_eq!(res_have_vi64, res_want_vi64);
+                    });
                 });
             });
         });
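In words, the rewritten test now checks the new scaled API against a reference computation: for every `shift`, the accumulator produced by `vmp_apply_add(&mut c_dft, &a_dft, &mat_znx_dft, shift, ...)` must equal the plain `vmp_apply` product shifted left by `shift * basek` bits via the new `vec_znx_shift_inplace`, plus the initial accumulator contents. The core identity is that a row offset of `shift` scales the product by 2^(shift*basek), i.e. moves it by `shift` whole base-2^basek limbs. A self-contained toy model of that limb/bit-shift equivalence (plain integers, not poulpy code; little-endian limb order is chosen here for clarity, whereas the patch's `lsh` rotates toward index 0 and zeroes the tail, which corresponds to a most-significant-first layout):

// Model: shifting an integer by shift*basek bits moves every base-2^basek
// limb by `shift` positions and zero-fills the vacated limbs.
fn to_limbs(mut x: u64, basek: u32, size: usize) -> Vec<u64> {
    // Little-endian limbs in base 2^basek.
    let mask = (1u64 << basek) - 1;
    (0..size)
        .map(|_| {
            let l = x & mask;
            x >>= basek;
            l
        })
        .collect()
}

fn main() {
    let basek = 8u32;
    let size = 6usize;
    let x: u64 = 0x1234;
    let shift = 2usize;

    // Shifting the integer by shift*basek bits...
    let shifted = x << (shift as u32 * basek);
    // ...is the same as prepending `shift` zero limbs.
    let mut expected = vec![0u64; shift];
    expected.extend(to_limbs(x, basek, size - shift));
    assert_eq!(to_limbs(shifted, basek, size), expected);
    println!("ok");
}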
@@ -131,6 +131,35 @@ impl<D: AsMut<[u8]> + AsRef<[u8]>> VecZnx<D> {
             })
         }
     }
+
+    pub fn lsh(&mut self, basek: usize, k: usize, scratch: &mut Scratch) {
+        let n: usize = self.n();
+        let cols: usize = self.cols();
+        let size: usize = self.size();
+        let steps: usize = k / basek;
+
+        self.raw_mut().rotate_left(n * steps * cols);
+        (0..cols).for_each(|i| {
+            (size - steps..size).for_each(|j| {
+                self.zero_at(i, j);
+            })
+        });
+
+        let k_rem: usize = k % basek;
+
+        if k_rem != 0 {
+            let shift: usize = i64::BITS as usize - k_rem;
+            let (tmp_bytes, _) = scratch.tmp_slice::<u8>(n * size_of::<i64>());
+            (0..cols).for_each(|i| {
+                (0..steps).for_each(|j| {
+                    self.at_mut(i, j).iter_mut().for_each(|xi| {
+                        *xi <<= shift;
+                    });
+                });
+                normalize(basek, self, i, tmp_bytes);
+            });
+        }
+    }
 }

 impl<D: From<Vec<u8>>> VecZnx<D> {
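The new `lsh` splits a left shift of `k` bits into two parts: the whole-limb part `k / basek` is a slice rotation toward limb index 0 followed by zeroing the vacated tail limbs, and the `k % basek` remainder is handled by pre-shifting raw limb values and letting `normalize` propagate the resulting carries. A self-contained sketch of the same arithmetic on toy `i64` limbs, with the carry propagation written out where the patch defers to `normalize` (assumptions: toy representation, most-significant-limb-first order as implied by `rotate_left`, non-negative limbs):

// Shift a base-2^basek number, stored most-significant limb first,
// left by k bits.
fn lsh_limbs(limbs: &mut Vec<i64>, basek: usize, k: usize) {
    let size = limbs.len();
    let steps = k / basek;

    // Whole-limb shift: move limbs toward index 0 (more significant)
    // and zero the vacated least-significant slots.
    limbs.rotate_left(steps.min(size));
    for j in size.saturating_sub(steps)..size {
        limbs[j] = 0;
    }

    // Sub-limb remainder: shift each limb and propagate carries from the
    // least significant limb upward (the role `normalize` plays above).
    let k_rem = k % basek;
    if k_rem != 0 {
        let mut carry = 0i64;
        for j in (0..size).rev() {
            let v = (limbs[j] << k_rem) + carry;
            limbs[j] = v & ((1 << basek) - 1);
            carry = v >> basek;
        }
    }
}

fn main() {
    let basek = 8;
    // 0x0102 stored most-significant-first over 4 limbs: [0, 0, 0x01, 0x02].
    let mut limbs = vec![0, 0, 0x01, 0x02];
    lsh_limbs(&mut limbs, basek, 12); // multiply by 2^12
    // 0x0102 << 12 = 0x102000 -> limbs [0x00, 0x10, 0x20, 0x00].
    assert_eq!(limbs, vec![0x00, 0x10, 0x20, 0x00]);
    println!("ok");
}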
@@ -104,6 +104,12 @@ pub trait VecZnxOps {
     where
         A: VecZnxToMut;

+    /// Shifts all columns of `a` by `k` bits.
+    /// A positive `k` applies a left shift, while a negative `k` applies a right shift.
+    fn vec_znx_shift_inplace<A>(&self, basek: usize, k: i64, a: &mut A, scratch: &mut Scratch)
+    where
+        A: VecZnxToMut;
+
     /// Multiplies the selected column of `a` by X^k and stores the result in `res_col` of `res`.
     fn vec_znx_rotate<R, A>(&self, k: i64, res: &mut R, res_col: usize, a: &A, a_col: usize)
     where
@@ -179,6 +185,17 @@ impl<B: Backend> VecZnxAlloc for Module<B> {
|
|||||||
}
|
}
|
||||||
|
|
||||||
impl<BACKEND: Backend> VecZnxOps for Module<BACKEND> {
|
impl<BACKEND: Backend> VecZnxOps for Module<BACKEND> {
|
||||||
|
fn vec_znx_shift_inplace<A>(&self, basek: usize, k: i64, a: &mut A, scratch: &mut Scratch)
|
||||||
|
where
|
||||||
|
A: VecZnxToMut,
|
||||||
|
{
|
||||||
|
if k > 0 {
|
||||||
|
a.to_mut().lsh(basek, k as usize, scratch);
|
||||||
|
} else {
|
||||||
|
a.to_mut().rsh(basek, k.abs() as usize, scratch);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
fn vec_znx_copy<R, A>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize)
|
fn vec_znx_copy<R, A>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize)
|
||||||
where
|
where
|
||||||
R: VecZnxToMut,
|
R: VecZnxToMut,
|
||||||
|
|||||||
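A usage sketch for the new method, assuming only the API visible in this patch (`Module`, `ScratchOwned`, `new_vec_znx`, `vec_znx_big_normalize_tmp_bytes`, and the pre-existing `rsh` that the `else` branch dispatches to); the crate path `poulpy` and the scratch sizing are assumptions, not confirmed by the diff:

// Positive k dispatches to lsh (multiply by 2^k), negative k to rsh.
use poulpy::{FFT64, Module, ScratchOwned, VecZnx, VecZnxAlloc, VecZnxOps};

fn main() {
    let basek: usize = 8;
    let module: Module<FFT64> = Module::<FFT64>::new(1 << 5);
    // Assumed large enough for the shift's temporary buffer.
    let mut scratch: ScratchOwned = ScratchOwned::new(module.vec_znx_big_normalize_tmp_bytes());
    let mut a: VecZnx<Vec<u8>> = module.new_vec_znx(1, 6);

    // Left shift by two whole limbs (2 * basek bits).
    module.vec_znx_shift_inplace(basek, 2 * basek as i64, &mut a, scratch.borrow());
    // Right shift by one limb.
    module.vec_znx_shift_inplace(basek, -(basek as i64), &mut a, scratch.borrow());
}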