Added API in poulpy for updated vmp_add (+tests)

This commit is contained in:
Jean-Philippe Bossuat
2025-06-04 11:39:11 +02:00
parent fcdc8f53d3
commit 159cd8025f
14 changed files with 216 additions and 82 deletions

View File

@@ -47,6 +47,7 @@ unsafe extern "C" {
pmat: *const VMP_PMAT,
nrows: u64,
ncols: u64,
pmat_scale: u64,
tmp_space: *mut u8,
);
}
@@ -79,6 +80,7 @@ unsafe extern "C" {
pmat: *const VMP_PMAT,
nrows: u64,
ncols: u64,
pmat_scale: u64,
tmp_space: *mut u8,
);
}

View File

@@ -101,7 +101,7 @@ pub trait MatZnxDftOps<BACKEND: Backend> {
B: MatZnxToRef<FFT64>;
// Same as [MatZnxDftOps::vmp_apply] except result is added on R instead of overwritting R.
fn vmp_apply_add<R, A, B>(&self, res: &mut R, a: &A, b: &B, scratch: &mut Scratch)
fn vmp_apply_add<R, A, B>(&self, res: &mut R, a: &A, b: &B, scale: usize, scratch: &mut Scratch)
where
R: VecZnxDftToMut<FFT64>,
A: VecZnxDftToRef<FFT64>,
@@ -309,7 +309,7 @@ impl MatZnxDftOps<FFT64> for Module<FFT64> {
}
}
fn vmp_apply_add<R, A, B>(&self, res: &mut R, a: &A, b: &B, scratch: &mut Scratch)
fn vmp_apply_add<R, A, B>(&self, res: &mut R, a: &A, b: &B, scale: usize, scratch: &mut Scratch)
where
R: VecZnxDftToMut<FFT64>,
A: VecZnxDftToRef<FFT64>,
@@ -358,6 +358,7 @@ impl MatZnxDftOps<FFT64> for Module<FFT64> {
b.as_ptr() as *const vmp::vmp_pmat_t,
(b.rows() * b.cols_in()) as u64,
(b.size() * b.cols_out()) as u64,
scale as u64,
tmp_bytes.as_mut_ptr(),
)
}
@@ -368,6 +369,7 @@ mod tests {
use crate::{
Decoding, FFT64, FillUniform, MatZnxDft, MatZnxDftOps, Module, ScratchOwned, VecZnx, VecZnxAlloc, VecZnxBig,
VecZnxBigAlloc, VecZnxBigOps, VecZnxBigScratch, VecZnxDft, VecZnxDftAlloc, VecZnxDftOps, ZnxInfos, ZnxView, ZnxViewMut,
ZnxZero,
};
use sampling::source::Source;
@@ -409,7 +411,7 @@ mod tests {
let basek: usize = 15;
let a_size: usize = 5;
let mat_size: usize = 6;
let res_size: usize = 5;
let res_size: usize = a_size;
[1, 2].iter().for_each(|in_cols| {
[1, 2].iter().for_each(|out_cols| {
@@ -419,7 +421,6 @@ mod tests {
let mat_rows: usize = a_size;
let mat_cols_in: usize = a_cols;
let mat_cols_out: usize = res_cols;
let res_cols: usize = mat_cols_out;
let mut scratch: ScratchOwned = ScratchOwned::new(
module.vmp_apply_tmp_bytes(
@@ -435,7 +436,7 @@ mod tests {
let mut a: VecZnx<Vec<u8>> = module.new_vec_znx(a_cols, a_size);
(0..a_cols).for_each(|i| {
a.at_mut(i, 2)[i + 1] = 1;
a.at_mut(i, a_size - 1)[i + 1] = 1;
});
let mut mat_znx_dft: MatZnxDft<Vec<u8>, FFT64> =
@@ -479,7 +480,100 @@ mod tests {
(0..a_cols).for_each(|i| {
res_want_vi64[(i + 1) + (1 + i * mat_cols_out + col_i)] = 1;
});
res_have.decode_vec_i64(col_i, basek, basek * 3, &mut res_have_vi64);
res_have.decode_vec_i64(col_i, basek, basek * a_size, &mut res_have_vi64);
assert_eq!(res_have_vi64, res_want_vi64);
});
});
});
}
#[test]
fn vmp_apply_add() {
let log_n: i32 = 5;
let n: usize = 1 << log_n;
let module: Module<FFT64> = Module::<FFT64>::new(n);
let basek: usize = 15;
let a_size: usize = 5;
let mat_size: usize = 6;
let res_size: usize = a_size;
[1, 2].iter().for_each(|in_cols| {
[1, 2].iter().for_each(|out_cols| {
let a_cols: usize = *in_cols;
let res_cols: usize = *out_cols;
let mat_rows: usize = a_size;
let mat_cols_in: usize = a_cols;
let mat_cols_out: usize = res_cols;
let mut scratch: ScratchOwned = ScratchOwned::new(
module.vmp_apply_tmp_bytes(
res_size,
a_size,
mat_rows,
mat_cols_in,
mat_cols_out,
mat_size,
) | module.vec_znx_big_normalize_tmp_bytes(),
);
let mut a: VecZnx<Vec<u8>> = module.new_vec_znx(a_cols, a_size);
(0..a_cols).for_each(|i| {
a.at_mut(i, a_size - 1)[i + 1] = 1;
});
let mut mat_znx_dft: MatZnxDft<Vec<u8>, FFT64> =
module.new_mat_znx_dft(mat_rows, mat_cols_in, mat_cols_out, mat_size);
let mut c_dft: VecZnxDft<Vec<u8>, FFT64> = module.new_vec_znx_dft(mat_cols_out, mat_size);
let mut c_big: VecZnxBig<Vec<u8>, FFT64> = module.new_vec_znx_big(mat_cols_out, mat_size);
let mut tmp: VecZnx<Vec<u8>> = module.new_vec_znx(mat_cols_out, mat_size);
// Construts a [VecZnxMatDft] that performs cyclic rotations on each submatrix.
(0..a.size()).for_each(|row_i| {
(0..mat_cols_in).for_each(|col_in_i| {
(0..mat_cols_out).for_each(|col_out_i| {
let idx = 1 + col_in_i * mat_cols_out + col_out_i;
tmp.at_mut(col_out_i, row_i)[idx] = 1 as i64; // X^{idx}
module.vec_znx_dft(&mut c_dft, col_out_i, &tmp, col_out_i);
tmp.at_mut(col_out_i, row_i)[idx] = 0 as i64;
});
module.vmp_prepare_row(&mut mat_znx_dft, row_i, col_in_i, &c_dft);
});
});
let mut a_dft: VecZnxDft<Vec<u8>, FFT64> = module.new_vec_znx_dft(a_cols, a_size);
(0..a_cols).for_each(|i| {
module.vec_znx_dft(&mut a_dft, i, &a, i);
});
c_dft.zero();
(0..c_dft.cols()).for_each(|i| {
module.vec_znx_dft(&mut c_dft, i, &a, 0);
});
module.vmp_apply_add(&mut c_dft, &a_dft, &mat_znx_dft, 0, scratch.borrow());
let mut res_have_vi64: Vec<i64> = vec![i64::default(); n];
let mut res_have: VecZnx<Vec<u8>> = module.new_vec_znx(res_cols, res_size);
(0..mat_cols_out).for_each(|i| {
module.vec_znx_idft_tmp_a(&mut c_big, i, &mut c_dft, i);
module.vec_znx_big_normalize(basek, &mut res_have, i, &c_big, i, scratch.borrow());
});
(0..mat_cols_out).for_each(|col_i| {
let mut res_want_vi64: Vec<i64> = vec![i64::default(); n];
(0..a_cols).for_each(|i| {
res_want_vi64[(i + 1) + (1 + i * mat_cols_out + col_i)] = 1;
});
res_want_vi64[1] += 1;
res_have.decode_vec_i64(col_i, basek, basek * a_size, &mut res_have_vi64);
assert_eq!(res_have_vi64, res_want_vi64);
});
});